Upload convert_to_onnx.py with huggingface_hub
Browse files- convert_to_onnx.py +148 -17
convert_to_onnx.py
CHANGED
|
@@ -18,11 +18,6 @@ import torch
|
|
| 18 |
import torch.nn as nn
|
| 19 |
import yaml
|
| 20 |
|
| 21 |
-
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'msst'))
|
| 22 |
-
|
| 23 |
-
from models.bs_roformer.bs_roformer import BSRoformer
|
| 24 |
-
|
| 25 |
-
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
# Wrapper that exposes only the exportable "core" of BSRoformer
|
| 28 |
# ---------------------------------------------------------------------------
|
|
@@ -34,7 +29,7 @@ class BSRoformerCore(nn.Module):
|
|
| 34 |
Output: mask – (B, num_stems, F, T, 2) complex mask as real tensor
|
| 35 |
"""
|
| 36 |
|
| 37 |
-
def __init__(self, model:
|
| 38 |
super().__init__()
|
| 39 |
self.band_split = model.band_split
|
| 40 |
self.layers = model.layers
|
|
@@ -93,7 +88,7 @@ class BSRoformerCore(nn.Module):
|
|
| 93 |
# Helper: replace PoPE attention with standard attention for ONNX export
|
| 94 |
# ---------------------------------------------------------------------------
|
| 95 |
|
| 96 |
-
def replace_pope_with_standard_attn(model:
|
| 97 |
"""Replace PoPE-based flash attention with equivalent standard attention.
|
| 98 |
|
| 99 |
PoPE attention applies: q, k = softplus(q), softplus(k) then rotates by
|
|
@@ -209,6 +204,11 @@ def replace_pope_with_standard_attn(model: BSRoformer):
|
|
| 209 |
|
| 210 |
def load_model(config_path, checkpoint_path):
|
| 211 |
"""Load BS PolarFormer model from config and checkpoint."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
with open(config_path, 'r') as f:
|
| 213 |
config = yaml.full_load(f)
|
| 214 |
|
|
@@ -503,6 +503,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 503 |
"""
|
| 504 |
import onnx
|
| 505 |
from onnx import helper, TensorProto, numpy_helper
|
|
|
|
| 506 |
|
| 507 |
print(f"\n=== Cascading large Split/Concat ops (max {max_bindings} bindings) ===")
|
| 508 |
model = onnx.load(onnx_path)
|
|
@@ -514,27 +515,38 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 514 |
type_map[vi.name] = vi
|
| 515 |
|
| 516 |
def _get_shape(tensor_name):
|
| 517 |
-
"""Get
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
vi = type_map.get(tensor_name)
|
| 519 |
if vi is None:
|
| 520 |
return None
|
| 521 |
try:
|
| 522 |
dims = vi.type.tensor_type.shape.dim
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
except Exception:
|
| 525 |
return None
|
| 526 |
|
| 527 |
def _make_value_info(name, shape_dims, elem_type=1):
|
| 528 |
"""Create a TensorValueInfoProto with the given shape dimensions.
|
| 529 |
-
shape_dims is a list of (dim_value, dim_param) tuples.
|
| 530 |
elem_type=1 means FLOAT."""
|
| 531 |
vi = helper.make_tensor_value_info(name, elem_type, [None] * len(shape_dims))
|
| 532 |
-
for i, (dv, dp) in enumerate(shape_dims):
|
| 533 |
dim = vi.type.tensor_type.shape.dim[i]
|
| 534 |
dim.Clear()
|
| 535 |
if dp:
|
| 536 |
dim.dim_param = dp
|
| 537 |
-
|
| 538 |
dim.dim_value = dv
|
| 539 |
return vi
|
| 540 |
|
|
@@ -559,7 +571,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 559 |
for oi, oname in enumerate(output_names):
|
| 560 |
if oname not in type_map:
|
| 561 |
shape = list(input_shape)
|
| 562 |
-
shape[a] = (sizes[oi], '')
|
| 563 |
vi = _make_value_info(oname, shape)
|
| 564 |
new_value_infos.append(vi)
|
| 565 |
type_map[oname] = vi
|
|
@@ -592,7 +604,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 592 |
shape = list(input_shape)
|
| 593 |
# Replace the split axis dimension with the group's concrete size
|
| 594 |
a = axis if axis >= 0 else len(shape) + axis
|
| 595 |
-
shape[a] = (group_sizes[gi], '') # concrete dim_value, no dim_param
|
| 596 |
vi = _make_value_info(gout, shape)
|
| 597 |
new_value_infos.append(vi)
|
| 598 |
type_map[gout] = vi
|
|
@@ -658,6 +670,28 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 658 |
|
| 659 |
inputs = list(node.input)
|
| 660 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
# Group inputs into chunks of max_bindings
|
| 662 |
groups = []
|
| 663 |
for i in range(0, n_in, max_bindings):
|
|
@@ -673,13 +707,13 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 673 |
total = 0
|
| 674 |
all_concrete = True
|
| 675 |
for s in shapes:
|
| 676 |
-
dv, dp = s[a]
|
| 677 |
-
if dp or
|
| 678 |
all_concrete = False
|
| 679 |
break
|
| 680 |
total += dv
|
| 681 |
if all_concrete:
|
| 682 |
-
base[a] = (total, '')
|
| 683 |
return base
|
| 684 |
return None
|
| 685 |
|
|
@@ -745,6 +779,50 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 745 |
graph.node.remove(node)
|
| 746 |
graph.node.extend(nodes_to_add)
|
| 747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
# Add explicit value_info entries for cascade intermediates (as hints).
|
| 749 |
for vi in new_value_infos:
|
| 750 |
graph.value_info.append(vi)
|
|
@@ -800,6 +878,59 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
|
|
| 800 |
else:
|
| 801 |
print(f" All intermediate tensors have proper shape annotations")
|
| 802 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
return out_path
|
| 804 |
|
| 805 |
|
|
|
|
| 18 |
import torch.nn as nn
|
| 19 |
import yaml
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ---------------------------------------------------------------------------
|
| 22 |
# Wrapper that exposes only the exportable "core" of BSRoformer
|
| 23 |
# ---------------------------------------------------------------------------
|
|
|
|
| 29 |
Output: mask – (B, num_stems, F, T, 2) complex mask as real tensor
|
| 30 |
"""
|
| 31 |
|
| 32 |
+
def __init__(self, model: nn.Module):
|
| 33 |
super().__init__()
|
| 34 |
self.band_split = model.band_split
|
| 35 |
self.layers = model.layers
|
|
|
|
| 88 |
# Helper: replace PoPE attention with standard attention for ONNX export
|
| 89 |
# ---------------------------------------------------------------------------
|
| 90 |
|
| 91 |
+
def replace_pope_with_standard_attn(model: nn.Module):
|
| 92 |
"""Replace PoPE-based flash attention with equivalent standard attention.
|
| 93 |
|
| 94 |
PoPE attention applies: q, k = softplus(q), softplus(k) then rotates by
|
|
|
|
| 204 |
|
| 205 |
def load_model(config_path, checkpoint_path):
|
| 206 |
"""Load BS PolarFormer model from config and checkpoint."""
|
| 207 |
+
msst_root = os.path.join(os.path.dirname(__file__), '..', 'msst')
|
| 208 |
+
if msst_root not in sys.path:
|
| 209 |
+
sys.path.insert(0, msst_root)
|
| 210 |
+
from models.bs_roformer.bs_roformer import BSRoformer
|
| 211 |
+
|
| 212 |
with open(config_path, 'r') as f:
|
| 213 |
config = yaml.full_load(f)
|
| 214 |
|
|
|
|
| 503 |
"""
|
| 504 |
import onnx
|
| 505 |
from onnx import helper, TensorProto, numpy_helper
|
| 506 |
+
from collections import defaultdict, deque
|
| 507 |
|
| 508 |
print(f"\n=== Cascading large Split/Concat ops (max {max_bindings} bindings) ===")
|
| 509 |
model = onnx.load(onnx_path)
|
|
|
|
| 515 |
type_map[vi.name] = vi
|
| 516 |
|
| 517 |
def _get_shape(tensor_name):
|
| 518 |
+
"""Get shape dims as (dim_value, dim_param, has_dim_value).
|
| 519 |
+
|
| 520 |
+
Important: in ONNX protobuf, dim_value defaults to 0 when unset.
|
| 521 |
+
We must track whether dim_value is explicitly set, otherwise unknown
|
| 522 |
+
dimensions can be misread as concrete zeros.
|
| 523 |
+
"""
|
| 524 |
vi = type_map.get(tensor_name)
|
| 525 |
if vi is None:
|
| 526 |
return None
|
| 527 |
try:
|
| 528 |
dims = vi.type.tensor_type.shape.dim
|
| 529 |
+
shape = []
|
| 530 |
+
for d in dims:
|
| 531 |
+
has_dim_value = d.HasField('dim_value')
|
| 532 |
+
dv = d.dim_value if has_dim_value else None
|
| 533 |
+
dp = d.dim_param if d.HasField('dim_param') else ''
|
| 534 |
+
shape.append((dv, dp, has_dim_value))
|
| 535 |
+
return shape
|
| 536 |
except Exception:
|
| 537 |
return None
|
| 538 |
|
| 539 |
def _make_value_info(name, shape_dims, elem_type=1):
|
| 540 |
"""Create a TensorValueInfoProto with the given shape dimensions.
|
| 541 |
+
shape_dims is a list of (dim_value, dim_param, has_dim_value) tuples.
|
| 542 |
elem_type=1 means FLOAT."""
|
| 543 |
vi = helper.make_tensor_value_info(name, elem_type, [None] * len(shape_dims))
|
| 544 |
+
for i, (dv, dp, has_dim_value) in enumerate(shape_dims):
|
| 545 |
dim = vi.type.tensor_type.shape.dim[i]
|
| 546 |
dim.Clear()
|
| 547 |
if dp:
|
| 548 |
dim.dim_param = dp
|
| 549 |
+
elif has_dim_value and dv is not None:
|
| 550 |
dim.dim_value = dv
|
| 551 |
return vi
|
| 552 |
|
|
|
|
| 571 |
for oi, oname in enumerate(output_names):
|
| 572 |
if oname not in type_map:
|
| 573 |
shape = list(input_shape)
|
| 574 |
+
shape[a] = (sizes[oi], '', True)
|
| 575 |
vi = _make_value_info(oname, shape)
|
| 576 |
new_value_infos.append(vi)
|
| 577 |
type_map[oname] = vi
|
|
|
|
| 604 |
shape = list(input_shape)
|
| 605 |
# Replace the split axis dimension with the group's concrete size
|
| 606 |
a = axis if axis >= 0 else len(shape) + axis
|
| 607 |
+
shape[a] = (group_sizes[gi], '', True) # concrete dim_value, no dim_param
|
| 608 |
vi = _make_value_info(gout, shape)
|
| 609 |
new_value_infos.append(vi)
|
| 610 |
type_map[gout] = vi
|
|
|
|
| 670 |
|
| 671 |
inputs = list(node.input)
|
| 672 |
|
| 673 |
+
# Filter out 0-sized inputs to avoid creating Concat nodes with ONLY 0-sized inputs
|
| 674 |
+
# (which crashes WebGPU's WGSL compiler).
|
| 675 |
+
valid_inputs = []
|
| 676 |
+
for inp in inputs:
|
| 677 |
+
shape = _get_shape(inp)
|
| 678 |
+
is_zero = False
|
| 679 |
+
if shape:
|
| 680 |
+
a = axis if axis >= 0 else len(shape) + axis
|
| 681 |
+
if a < len(shape):
|
| 682 |
+
dv, dp, has_dv = shape[a]
|
| 683 |
+
if has_dv and dv == 0 and not dp:
|
| 684 |
+
is_zero = True
|
| 685 |
+
if not is_zero:
|
| 686 |
+
valid_inputs.append(inp)
|
| 687 |
+
|
| 688 |
+
# If all inputs were 0-sized, keep at least one so the Concat node is valid
|
| 689 |
+
if not valid_inputs and inputs:
|
| 690 |
+
valid_inputs = [inputs[0]]
|
| 691 |
+
|
| 692 |
+
inputs = valid_inputs
|
| 693 |
+
n_in = len(inputs)
|
| 694 |
+
|
| 695 |
# Group inputs into chunks of max_bindings
|
| 696 |
groups = []
|
| 697 |
for i in range(0, n_in, max_bindings):
|
|
|
|
| 707 |
total = 0
|
| 708 |
all_concrete = True
|
| 709 |
for s in shapes:
|
| 710 |
+
dv, dp, has_dv = s[a]
|
| 711 |
+
if dp or not has_dv:
|
| 712 |
all_concrete = False
|
| 713 |
break
|
| 714 |
total += dv
|
| 715 |
if all_concrete:
|
| 716 |
+
base[a] = (total, '', True)
|
| 717 |
return base
|
| 718 |
return None
|
| 719 |
|
|
|
|
| 779 |
graph.node.remove(node)
|
| 780 |
graph.node.extend(nodes_to_add)
|
| 781 |
|
| 782 |
+
# Re-establish topological ordering after graph rewrites.
|
| 783 |
+
# Appending new nodes at the end can violate node dependency order.
|
| 784 |
+
def _toposort_nodes(g):
|
| 785 |
+
producers = {}
|
| 786 |
+
for i, node in enumerate(g.node):
|
| 787 |
+
for out in node.output:
|
| 788 |
+
if out:
|
| 789 |
+
producers[out] = i
|
| 790 |
+
|
| 791 |
+
deps = defaultdict(set)
|
| 792 |
+
users = defaultdict(set)
|
| 793 |
+
indegree = [0] * len(g.node)
|
| 794 |
+
|
| 795 |
+
for i, node in enumerate(g.node):
|
| 796 |
+
for inp in node.input:
|
| 797 |
+
if not inp:
|
| 798 |
+
continue
|
| 799 |
+
p = producers.get(inp)
|
| 800 |
+
if p is None or p == i:
|
| 801 |
+
continue
|
| 802 |
+
if p not in deps[i]:
|
| 803 |
+
deps[i].add(p)
|
| 804 |
+
users[p].add(i)
|
| 805 |
+
indegree[i] += 1
|
| 806 |
+
|
| 807 |
+
q = deque([i for i, d in enumerate(indegree) if d == 0])
|
| 808 |
+
order = []
|
| 809 |
+
while q:
|
| 810 |
+
cur = q.popleft()
|
| 811 |
+
order.append(cur)
|
| 812 |
+
for nxt in users[cur]:
|
| 813 |
+
indegree[nxt] -= 1
|
| 814 |
+
if indegree[nxt] == 0:
|
| 815 |
+
q.append(nxt)
|
| 816 |
+
|
| 817 |
+
if len(order) != len(g.node):
|
| 818 |
+
raise RuntimeError("Failed to topologically sort rewritten ONNX graph (cycle or missing dependency).")
|
| 819 |
+
|
| 820 |
+
sorted_nodes = [g.node[i] for i in order]
|
| 821 |
+
del g.node[:]
|
| 822 |
+
g.node.extend(sorted_nodes)
|
| 823 |
+
|
| 824 |
+
_toposort_nodes(graph)
|
| 825 |
+
|
| 826 |
# Add explicit value_info entries for cascade intermediates (as hints).
|
| 827 |
for vi in new_value_infos:
|
| 828 |
graph.value_info.append(vi)
|
|
|
|
| 878 |
else:
|
| 879 |
print(f" All intermediate tensors have proper shape annotations")
|
| 880 |
|
| 881 |
+
# Extra safety check: reject models where a Concat would end up with only
|
| 882 |
+
# statically zero-sized inputs on its concat axis (known to break WebGPU WGSL).
|
| 883 |
+
vi_map2 = {}
|
| 884 |
+
for vi in list(model2.graph.value_info) + list(model2.graph.input) + list(model2.graph.output):
|
| 885 |
+
vi_map2[vi.name] = vi
|
| 886 |
+
|
| 887 |
+
def _shape_from_vi(vi):
|
| 888 |
+
try:
|
| 889 |
+
dims = vi.type.tensor_type.shape.dim
|
| 890 |
+
out = []
|
| 891 |
+
for d in dims:
|
| 892 |
+
has_dv = d.HasField('dim_value')
|
| 893 |
+
dv = d.dim_value if has_dv else None
|
| 894 |
+
dp = d.dim_param if d.HasField('dim_param') else ''
|
| 895 |
+
out.append((dv, dp, has_dv))
|
| 896 |
+
return out
|
| 897 |
+
except Exception:
|
| 898 |
+
return None
|
| 899 |
+
|
| 900 |
+
bad_concat = 0
|
| 901 |
+
for node in model2.graph.node:
|
| 902 |
+
if node.op_type != 'Concat':
|
| 903 |
+
continue
|
| 904 |
+
axis = 0
|
| 905 |
+
for attr in node.attribute:
|
| 906 |
+
if attr.name == 'axis':
|
| 907 |
+
axis = attr.i
|
| 908 |
+
all_zero = True
|
| 909 |
+
for inp in node.input:
|
| 910 |
+
vi = vi_map2.get(inp)
|
| 911 |
+
shape = _shape_from_vi(vi) if vi is not None else None
|
| 912 |
+
if not shape:
|
| 913 |
+
all_zero = False
|
| 914 |
+
continue
|
| 915 |
+
a = axis if axis >= 0 else len(shape) + axis
|
| 916 |
+
if a < 0 or a >= len(shape):
|
| 917 |
+
all_zero = False
|
| 918 |
+
continue
|
| 919 |
+
dv, dp, has_dv = shape[a]
|
| 920 |
+
if dp or not has_dv:
|
| 921 |
+
all_zero = False
|
| 922 |
+
elif dv > 0:
|
| 923 |
+
all_zero = False
|
| 924 |
+
if all_zero or not node.input:
|
| 925 |
+
bad_concat += 1
|
| 926 |
+
print(f" ERROR: problematic Concat '{node.name}' may produce invalid WebGPU shader")
|
| 927 |
+
|
| 928 |
+
if bad_concat:
|
| 929 |
+
raise RuntimeError(
|
| 930 |
+
f"Detected {bad_concat} problematic Concat nodes after cascading; "
|
| 931 |
+
"refuse to output potentially invalid WebGPU model."
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
return out_path
|
| 935 |
|
| 936 |
|