justinchuby committed on
Commit
650fcdf
·
verified ·
1 Parent(s): 750a4da

Add ReverseSequence rule

Browse files
Files changed (1) hide show
  1. scripts/optimize.py +50 -1
scripts/optimize.py CHANGED
@@ -2,6 +2,7 @@ import onnxscript
2
  import onnx_ir as ir
3
  import onnx_ir.passes.common
4
  import numpy as np
 
5
 
6
 
7
  class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
@@ -64,6 +65,37 @@ class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
64
  return op.Identity(x)
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  model = ir.load("model.onnx")
68
 
69
  # Set dynamic axes
@@ -72,7 +104,11 @@ model.graph.outputs[0].shape = ir.Shape(["batch", 6522])
72
 
73
  onnxscript.rewriter.rewrite(
74
  model,
75
- [ReplaceDftWithMatMulRule().rule(), ReplaceSplit().rule(), RemoveCast().rule()],
 
 
 
 
76
  )
77
 
78
  # Change all int32 initializers to int64
@@ -131,6 +167,19 @@ onnxscript.optimizer.optimize(
131
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
132
  )
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
136
  model.graph.inputs[0].name = "input"
 
2
  import onnx_ir as ir
3
  import onnx_ir.passes.common
4
  import numpy as np
5
+ import onnxslim
6
 
7
 
8
  class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
 
65
  return op.Identity(x)
66
 
67
 
68
class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
    """Rewrite rule that swaps a two-branch ReverseSequence subgraph for Slice ops.

    The matched subgraph runs two inputs through Transpose -> ReverseSequence ->
    Unsqueeze, concatenates the branches, applies ``* scale + bias`` and ends in a
    Transpose. The replacement drops the leading Transposes and the
    ReverseSequence nodes and reverses each input with a negative-step Slice.
    """

    def pattern(self, op, x, y, scale, bias):
        """Describe the subgraph to match.

        ``x`` and ``y`` are the two branch inputs; ``scale`` and ``bias`` are the
        operands of the trailing Mul and Add. ReverseSequence and Unsqueeze are
        matched with ``_allow_other_inputs=True`` so their remaining inputs are
        not constrained by this pattern.
        """
        x_transposed = op.Transpose(x)
        y_transposed = op.Transpose(y)
        x_reversed = op.ReverseSequence(x_transposed, _allow_other_inputs=True)
        y_reversed = op.ReverseSequence(y_transposed, _allow_other_inputs=True)
        x_expanded = op.Unsqueeze(x_reversed, _allow_other_inputs=True)
        y_expanded = op.Unsqueeze(y_reversed, _allow_other_inputs=True)
        joined = op.Concat(x_expanded, y_expanded)
        return op.Transpose(op.Add(op.Mul(joined, scale), bias))

    def rewrite(self, op, x, y, scale, bias, **kwargs):
        """Build the replacement subgraph for a match of :meth:`pattern`.

        Shape notes are carried over from the original author and assume ``x``
        is (batch, 511, 96) -- TODO confirm against the actual model.
        """
        # Shared int64 constants registered as initializers.
        minus_one = op.initializer(ir.tensor(np.array([-1], dtype=np.int64)), "neg_one")
        lowest_int64 = op.initializer(
            ir.tensor(np.array([np.iinfo(np.int64).min], dtype=np.int64)), "int_64_min"
        )

        # Slice(data, starts=-1, ends=INT64_MIN, axes=-1, steps=-1): walk the last
        # axis backwards from its final element, i.e. reverse that axis.
        x_flipped = op.Slice(x, minus_one, lowest_int64, minus_one, minus_one)
        y_flipped = op.Slice(y, minus_one, lowest_int64, minus_one, minus_one)

        # Append a trailing unit axis to each branch so they can be stacked.
        x_expanded = op.Unsqueeze(x_flipped, minus_one)
        y_expanded = op.Unsqueeze(y_flipped, minus_one)

        # Author's note: batch, 511, 96, 2
        stacked = op.Concat(x_expanded, y_expanded, axis=3)
        scaled = op.Mul(stacked, scale)
        shifted = op.Add(scaled, bias)

        # Author's note: result is batch, 2, 96, 511
        return op.Transpose(shifted, perm=[0, 3, 2, 1])
97
+
98
+
99
  model = ir.load("model.onnx")
100
 
101
  # Set dynamic axes
 
104
 
105
  onnxscript.rewriter.rewrite(
106
  model,
107
+ [
108
+ ReplaceDftWithMatMulRule().rule(),
109
+ ReplaceSplit().rule(),
110
+ RemoveCast().rule(),
111
+ ],
112
  )
113
 
114
  # Change all int32 initializers to int64
 
167
  model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
168
  )
169
 
170
# First onnxslim pass: convert the IR model to its proto form, slim it, and
# load the result back into the IR before the fork rewrite runs.
print("Slimming model...")
slimmed_proto = onnxslim.slim(ir.to_proto(model))
model = ir.from_proto(slimmed_proto)

# Apply the ReverseSequence fork rule on its own, after slimming.
print("Removing reversed sequence fork...")
fork_rules = [RemoveReversedSequenceFork.rule()]
onnxscript.rewriter.rewrite(model, fork_rules)

# Use onnxslim to do shape inference
reslimmed_proto = onnxslim.slim(ir.to_proto(model))
model = ir.from_proto(reslimmed_proto)
183
 
184
  onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
185
  model.graph.inputs[0].name = "input"