Add ReverseSequence rule
Browse files- scripts/optimize.py +50 -1
scripts/optimize.py
CHANGED
|
@@ -2,6 +2,7 @@ import onnxscript
|
|
| 2 |
import onnx_ir as ir
|
| 3 |
import onnx_ir.passes.common
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
|
|
@@ -64,6 +65,37 @@ class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
|
|
| 64 |
return op.Identity(x)
|
| 65 |
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
model = ir.load("model.onnx")
|
| 68 |
|
| 69 |
# Set dynamic axes
|
|
@@ -72,7 +104,11 @@ model.graph.outputs[0].shape = ir.Shape(["batch", 6522])
|
|
| 72 |
|
| 73 |
onnxscript.rewriter.rewrite(
|
| 74 |
model,
|
| 75 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
|
| 78 |
# Change all int32 initializers to int64
|
|
@@ -131,6 +167,19 @@ onnxscript.optimizer.optimize(
|
|
| 131 |
model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
|
| 132 |
)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
|
| 136 |
model.graph.inputs[0].name = "input"
|
|
|
|
| 2 |
import onnx_ir as ir
|
| 3 |
import onnx_ir.passes.common
|
| 4 |
import numpy as np
|
| 5 |
+
import onnxslim
|
| 6 |
|
| 7 |
|
| 8 |
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
|
|
|
|
| 65 |
return op.Identity(x)
|
| 66 |
|
| 67 |
|
| 68 |
+
class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
|
| 69 |
+
def pattern(self, op, x, y, scale, bias):
|
| 70 |
+
x = op.Transpose(x)
|
| 71 |
+
y = op.Transpose(y)
|
| 72 |
+
x = op.ReverseSequence(x, _allow_other_inputs=True)
|
| 73 |
+
y = op.ReverseSequence(y, _allow_other_inputs=True)
|
| 74 |
+
x = op.Unsqueeze(x, _allow_other_inputs=True)
|
| 75 |
+
y = op.Unsqueeze(y, _allow_other_inputs=True)
|
| 76 |
+
concat = op.Concat(x, y)
|
| 77 |
+
mul = op.Mul(concat, scale)
|
| 78 |
+
add = op.Add(mul, bias)
|
| 79 |
+
return op.Transpose(add)
|
| 80 |
+
|
| 81 |
+
def rewrite(self, op, x, y, scale, bias, **kwargs):
|
| 82 |
+
# x: batch, 511, 96
|
| 83 |
+
neg_one = op.initializer(ir.tensor(np.array([-1], dtype=np.int64)), "neg_one")
|
| 84 |
+
int_64_min = op.initializer(
|
| 85 |
+
ir.tensor(np.array([-9223372036854775808], dtype=np.int64)), "int_64_min"
|
| 86 |
+
)
|
| 87 |
+
# slice
|
| 88 |
+
x = op.Slice(x, neg_one, int_64_min, neg_one, neg_one)
|
| 89 |
+
y = op.Slice(y, neg_one, int_64_min, neg_one, neg_one)
|
| 90 |
+
x = op.Unsqueeze(x, neg_one)
|
| 91 |
+
y = op.Unsqueeze(y, neg_one)
|
| 92 |
+
concat = op.Concat(x, y, axis=3)
|
| 93 |
+
# batch, 511, 96, 2
|
| 94 |
+
mul = op.Mul(concat, scale)
|
| 95 |
+
add = op.Add(mul, bias)
|
| 96 |
+
return op.Transpose(add, perm=[0, 3, 2, 1]) # batch, 2, 96, 511
|
| 97 |
+
|
| 98 |
+
|
| 99 |
model = ir.load("model.onnx")
|
| 100 |
|
| 101 |
# Set dynamic axes
|
|
|
|
| 104 |
|
| 105 |
onnxscript.rewriter.rewrite(
|
| 106 |
model,
|
| 107 |
+
[
|
| 108 |
+
ReplaceDftWithMatMulRule().rule(),
|
| 109 |
+
ReplaceSplit().rule(),
|
| 110 |
+
RemoveCast().rule(),
|
| 111 |
+
],
|
| 112 |
)
|
| 113 |
|
| 114 |
# Change all int32 initializers to int64
|
|
|
|
| 167 |
model, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
|
| 168 |
)
|
| 169 |
|
| 170 |
+
print("Slimming model...")
|
| 171 |
+
model = ir.from_proto(onnxslim.slim(ir.to_proto(model)))
|
| 172 |
+
|
| 173 |
+
print("Removing reversed sequence fork...")
|
| 174 |
+
onnxscript.rewriter.rewrite(
|
| 175 |
+
model,
|
| 176 |
+
[
|
| 177 |
+
RemoveReversedSequenceFork.rule(),
|
| 178 |
+
],
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Use onnxslim to do shape inference
|
| 182 |
+
model = ir.from_proto(onnxslim.slim(ir.to_proto(model)))
|
| 183 |
|
| 184 |
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
|
| 185 |
model.graph.inputs[0].name = "input"
|