bgkb committed on
Commit
1a36bca
·
verified ·
1 Parent(s): 614295a

Upload convert_to_onnx.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_to_onnx.py +148 -17
convert_to_onnx.py CHANGED
@@ -18,11 +18,6 @@ import torch
18
  import torch.nn as nn
19
  import yaml
20
 
21
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'msst'))
22
-
23
- from models.bs_roformer.bs_roformer import BSRoformer
24
-
25
-
26
  # ---------------------------------------------------------------------------
27
  # Wrapper that exposes only the exportable "core" of BSRoformer
28
  # ---------------------------------------------------------------------------
@@ -34,7 +29,7 @@ class BSRoformerCore(nn.Module):
34
  Output: mask – (B, num_stems, F, T, 2) complex mask as real tensor
35
  """
36
 
37
- def __init__(self, model: BSRoformer):
38
  super().__init__()
39
  self.band_split = model.band_split
40
  self.layers = model.layers
@@ -93,7 +88,7 @@ class BSRoformerCore(nn.Module):
93
  # Helper: replace PoPE attention with standard attention for ONNX export
94
  # ---------------------------------------------------------------------------
95
 
96
- def replace_pope_with_standard_attn(model: BSRoformer):
97
  """Replace PoPE-based flash attention with equivalent standard attention.
98
 
99
  PoPE attention applies: q, k = softplus(q), softplus(k) then rotates by
@@ -209,6 +204,11 @@ def replace_pope_with_standard_attn(model: BSRoformer):
209
 
210
  def load_model(config_path, checkpoint_path):
211
  """Load BS PolarFormer model from config and checkpoint."""
 
 
 
 
 
212
  with open(config_path, 'r') as f:
213
  config = yaml.full_load(f)
214
 
@@ -503,6 +503,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
503
  """
504
  import onnx
505
  from onnx import helper, TensorProto, numpy_helper
 
506
 
507
  print(f"\n=== Cascading large Split/Concat ops (max {max_bindings} bindings) ===")
508
  model = onnx.load(onnx_path)
@@ -514,27 +515,38 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
514
  type_map[vi.name] = vi
515
 
516
  def _get_shape(tensor_name):
517
- """Get the shape of a tensor as a list of (dim_value, dim_param) tuples."""
 
 
 
 
 
518
  vi = type_map.get(tensor_name)
519
  if vi is None:
520
  return None
521
  try:
522
  dims = vi.type.tensor_type.shape.dim
523
- return [(d.dim_value, d.dim_param) for d in dims]
 
 
 
 
 
 
524
  except Exception:
525
  return None
526
 
527
  def _make_value_info(name, shape_dims, elem_type=1):
528
  """Create a TensorValueInfoProto with the given shape dimensions.
529
- shape_dims is a list of (dim_value, dim_param) tuples.
530
  elem_type=1 means FLOAT."""
531
  vi = helper.make_tensor_value_info(name, elem_type, [None] * len(shape_dims))
532
- for i, (dv, dp) in enumerate(shape_dims):
533
  dim = vi.type.tensor_type.shape.dim[i]
534
  dim.Clear()
535
  if dp:
536
  dim.dim_param = dp
537
- else:
538
  dim.dim_value = dv
539
  return vi
540
 
@@ -559,7 +571,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
559
  for oi, oname in enumerate(output_names):
560
  if oname not in type_map:
561
  shape = list(input_shape)
562
- shape[a] = (sizes[oi], '')
563
  vi = _make_value_info(oname, shape)
564
  new_value_infos.append(vi)
565
  type_map[oname] = vi
@@ -592,7 +604,7 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
592
  shape = list(input_shape)
593
  # Replace the split axis dimension with the group's concrete size
594
  a = axis if axis >= 0 else len(shape) + axis
595
- shape[a] = (group_sizes[gi], '') # concrete dim_value, no dim_param
596
  vi = _make_value_info(gout, shape)
597
  new_value_infos.append(vi)
598
  type_map[gout] = vi
@@ -658,6 +670,28 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
658
 
659
  inputs = list(node.input)
660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
  # Group inputs into chunks of max_bindings
662
  groups = []
663
  for i in range(0, n_in, max_bindings):
@@ -673,13 +707,13 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
673
  total = 0
674
  all_concrete = True
675
  for s in shapes:
676
- dv, dp = s[a]
677
- if dp or dv == 0:
678
  all_concrete = False
679
  break
680
  total += dv
681
  if all_concrete:
682
- base[a] = (total, '')
683
  return base
684
  return None
685
 
@@ -745,6 +779,50 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
745
  graph.node.remove(node)
746
  graph.node.extend(nodes_to_add)
747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748
  # Add explicit value_info entries for cascade intermediates (as hints).
749
  for vi in new_value_infos:
750
  graph.value_info.append(vi)
@@ -800,6 +878,59 @@ def cascade_large_ops(onnx_path, out_path, max_bindings=7):
800
  else:
801
  print(f" All intermediate tensors have proper shape annotations")
802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
  return out_path
804
 
805
 
 
18
  import torch.nn as nn
19
  import yaml
20
 
 
 
 
 
 
21
  # ---------------------------------------------------------------------------
22
  # Wrapper that exposes only the exportable "core" of BSRoformer
23
  # ---------------------------------------------------------------------------
 
29
  Output: mask – (B, num_stems, F, T, 2) complex mask as real tensor
30
  """
31
 
32
+ def __init__(self, model: nn.Module):
33
  super().__init__()
34
  self.band_split = model.band_split
35
  self.layers = model.layers
 
88
  # Helper: replace PoPE attention with standard attention for ONNX export
89
  # ---------------------------------------------------------------------------
90
 
91
+ def replace_pope_with_standard_attn(model: nn.Module):
92
  """Replace PoPE-based flash attention with equivalent standard attention.
93
 
94
  PoPE attention applies: q, k = softplus(q), softplus(k) then rotates by
 
204
 
205
  def load_model(config_path, checkpoint_path):
206
  """Load BS PolarFormer model from config and checkpoint."""
207
+ msst_root = os.path.join(os.path.dirname(__file__), '..', 'msst')
208
+ if msst_root not in sys.path:
209
+ sys.path.insert(0, msst_root)
210
+ from models.bs_roformer.bs_roformer import BSRoformer
211
+
212
  with open(config_path, 'r') as f:
213
  config = yaml.full_load(f)
214
 
 
503
  """
504
  import onnx
505
  from onnx import helper, TensorProto, numpy_helper
506
+ from collections import defaultdict, deque
507
 
508
  print(f"\n=== Cascading large Split/Concat ops (max {max_bindings} bindings) ===")
509
  model = onnx.load(onnx_path)
 
515
  type_map[vi.name] = vi
516
 
517
  def _get_shape(tensor_name):
518
+ """Get shape dims as (dim_value, dim_param, has_dim_value).
519
+
520
+ Important: in ONNX protobuf, dim_value defaults to 0 when unset.
521
+ We must track whether dim_value is explicitly set, otherwise unknown
522
+ dimensions can be misread as concrete zeros.
523
+ """
524
  vi = type_map.get(tensor_name)
525
  if vi is None:
526
  return None
527
  try:
528
  dims = vi.type.tensor_type.shape.dim
529
+ shape = []
530
+ for d in dims:
531
+ has_dim_value = d.HasField('dim_value')
532
+ dv = d.dim_value if has_dim_value else None
533
+ dp = d.dim_param if d.HasField('dim_param') else ''
534
+ shape.append((dv, dp, has_dim_value))
535
+ return shape
536
  except Exception:
537
  return None
538
 
539
  def _make_value_info(name, shape_dims, elem_type=1):
540
  """Create a TensorValueInfoProto with the given shape dimensions.
541
+ shape_dims is a list of (dim_value, dim_param, has_dim_value) tuples.
542
  elem_type=1 means FLOAT."""
543
  vi = helper.make_tensor_value_info(name, elem_type, [None] * len(shape_dims))
544
+ for i, (dv, dp, has_dim_value) in enumerate(shape_dims):
545
  dim = vi.type.tensor_type.shape.dim[i]
546
  dim.Clear()
547
  if dp:
548
  dim.dim_param = dp
549
+ elif has_dim_value and dv is not None:
550
  dim.dim_value = dv
551
  return vi
552
 
 
571
  for oi, oname in enumerate(output_names):
572
  if oname not in type_map:
573
  shape = list(input_shape)
574
+ shape[a] = (sizes[oi], '', True)
575
  vi = _make_value_info(oname, shape)
576
  new_value_infos.append(vi)
577
  type_map[oname] = vi
 
604
  shape = list(input_shape)
605
  # Replace the split axis dimension with the group's concrete size
606
  a = axis if axis >= 0 else len(shape) + axis
607
+ shape[a] = (group_sizes[gi], '', True) # concrete dim_value, no dim_param
608
  vi = _make_value_info(gout, shape)
609
  new_value_infos.append(vi)
610
  type_map[gout] = vi
 
670
 
671
  inputs = list(node.input)
672
 
673
+ # Filter out 0-sized inputs to avoid creating Concat nodes with ONLY 0-sized inputs
674
+ # (which crashes WebGPU's WGSL compiler).
675
+ valid_inputs = []
676
+ for inp in inputs:
677
+ shape = _get_shape(inp)
678
+ is_zero = False
679
+ if shape:
680
+ a = axis if axis >= 0 else len(shape) + axis
681
+ if a < len(shape):
682
+ dv, dp, has_dv = shape[a]
683
+ if has_dv and dv == 0 and not dp:
684
+ is_zero = True
685
+ if not is_zero:
686
+ valid_inputs.append(inp)
687
+
688
+ # If all inputs were 0-sized, keep at least one so the Concat node is valid
689
+ if not valid_inputs and inputs:
690
+ valid_inputs = [inputs[0]]
691
+
692
+ inputs = valid_inputs
693
+ n_in = len(inputs)
694
+
695
  # Group inputs into chunks of max_bindings
696
  groups = []
697
  for i in range(0, n_in, max_bindings):
 
707
  total = 0
708
  all_concrete = True
709
  for s in shapes:
710
+ dv, dp, has_dv = s[a]
711
+ if dp or not has_dv:
712
  all_concrete = False
713
  break
714
  total += dv
715
  if all_concrete:
716
+ base[a] = (total, '', True)
717
  return base
718
  return None
719
 
 
779
  graph.node.remove(node)
780
  graph.node.extend(nodes_to_add)
781
 
782
+ # Re-establish topological ordering after graph rewrites.
783
+ # Appending new nodes at the end can violate node dependency order.
784
+ def _toposort_nodes(g):
785
+ producers = {}
786
+ for i, node in enumerate(g.node):
787
+ for out in node.output:
788
+ if out:
789
+ producers[out] = i
790
+
791
+ deps = defaultdict(set)
792
+ users = defaultdict(set)
793
+ indegree = [0] * len(g.node)
794
+
795
+ for i, node in enumerate(g.node):
796
+ for inp in node.input:
797
+ if not inp:
798
+ continue
799
+ p = producers.get(inp)
800
+ if p is None or p == i:
801
+ continue
802
+ if p not in deps[i]:
803
+ deps[i].add(p)
804
+ users[p].add(i)
805
+ indegree[i] += 1
806
+
807
+ q = deque([i for i, d in enumerate(indegree) if d == 0])
808
+ order = []
809
+ while q:
810
+ cur = q.popleft()
811
+ order.append(cur)
812
+ for nxt in users[cur]:
813
+ indegree[nxt] -= 1
814
+ if indegree[nxt] == 0:
815
+ q.append(nxt)
816
+
817
+ if len(order) != len(g.node):
818
+ raise RuntimeError("Failed to topologically sort rewritten ONNX graph (cycle or missing dependency).")
819
+
820
+ sorted_nodes = [g.node[i] for i in order]
821
+ del g.node[:]
822
+ g.node.extend(sorted_nodes)
823
+
824
+ _toposort_nodes(graph)
825
+
826
  # Add explicit value_info entries for cascade intermediates (as hints).
827
  for vi in new_value_infos:
828
  graph.value_info.append(vi)
 
878
  else:
879
  print(f" All intermediate tensors have proper shape annotations")
880
 
881
+ # Extra safety check: reject models where a Concat would end up with only
882
+ # statically zero-sized inputs on its concat axis (known to break WebGPU WGSL).
883
+ vi_map2 = {}
884
+ for vi in list(model2.graph.value_info) + list(model2.graph.input) + list(model2.graph.output):
885
+ vi_map2[vi.name] = vi
886
+
887
+ def _shape_from_vi(vi):
888
+ try:
889
+ dims = vi.type.tensor_type.shape.dim
890
+ out = []
891
+ for d in dims:
892
+ has_dv = d.HasField('dim_value')
893
+ dv = d.dim_value if has_dv else None
894
+ dp = d.dim_param if d.HasField('dim_param') else ''
895
+ out.append((dv, dp, has_dv))
896
+ return out
897
+ except Exception:
898
+ return None
899
+
900
+ bad_concat = 0
901
+ for node in model2.graph.node:
902
+ if node.op_type != 'Concat':
903
+ continue
904
+ axis = 0
905
+ for attr in node.attribute:
906
+ if attr.name == 'axis':
907
+ axis = attr.i
908
+ all_zero = True
909
+ for inp in node.input:
910
+ vi = vi_map2.get(inp)
911
+ shape = _shape_from_vi(vi) if vi is not None else None
912
+ if not shape:
913
+ all_zero = False
914
+ continue
915
+ a = axis if axis >= 0 else len(shape) + axis
916
+ if a < 0 or a >= len(shape):
917
+ all_zero = False
918
+ continue
919
+ dv, dp, has_dv = shape[a]
920
+ if dp or not has_dv:
921
+ all_zero = False
922
+ elif dv > 0:
923
+ all_zero = False
924
+ if all_zero or not node.input:
925
+ bad_concat += 1
926
+ print(f" ERROR: problematic Concat '{node.name}' may produce invalid WebGPU shader")
927
+
928
+ if bad_concat:
929
+ raise RuntimeError(
930
+ f"Detected {bad_concat} problematic Concat nodes after cascading; "
931
+ "refuse to output potentially invalid WebGPU model."
932
+ )
933
+
934
  return out_path
935
 
936