# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization ∘ Conv -> Conv
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
- BatchNormalization ∘ Gemm -> Gemm
Approach:
Given an inbound operation output: Y = W * X + B
And a BatchNormalization outputs: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)
The fusion updates the inbound weights as follows:
- W_fused = W * (gamma / std)
- B_fused = (B - μ) * (gamma / std) + β
"""
from abc import ABC, abstractmethod
from typing import Mapping
import numpy as np
from onnxscript import ir
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
# Build shape: 1s everywhere except -1 at the target axis
broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
return np.reshape(x, broadcast_shape)
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Shared machinery for folding a BatchNormalization node into its producer.

    Subclasses supply the matched pattern and the axis along which the
    per-channel BatchNorm scale must be broadcast over the producer's weights.
    """

    def __init__(
        self,
        op_type: str,
        name: str | None = None,
        remove_nodes: bool = True,
        as_function: bool = False,
    ) -> None:
        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
        # Operator type of the inbound node ("Conv", "ConvTranspose" or "Gemm").
        self.op_type = op_type

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        """Emit a single fused node with the BatchNorm folded into weights/bias."""
        bn_node = batchnorm_out.producer()

        # BatchNorm inputs after X are: scale (gamma), bias (beta), mean, variance.
        gamma, beta, mean, variance = (
            value.const_value.numpy() for value in bn_node.inputs[1:]
        )

        # 1e-5 is the default epsilon per the ONNX spec:
        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
        eps_attr = bn_node.attributes.get("epsilon")
        eps = 1e-5 if eps_attr is None else eps_attr.as_float()

        # Common multiplier folded into both the weights and the bias.
        scale = gamma / np.sqrt(variance + eps)

        producer = inbound_out.producer()
        weights = producer.inputs[1].const_value.numpy()

        # Broadcast the per-channel scale over the producer's weight layout.
        axis = self.get_filters_axis(producer.attributes)
        fused_weights = ir.tensor(
            weights * _reshape_for_broadcast(scale, weights.ndim, axis=axis)
        )

        # Fold the BatchNorm shift into the bias, synthesizing a zero bias
        # when the producer has none.
        if len(producer.inputs) > 2:
            bias = producer.inputs[2].const_value.numpy()
            bias_name = producer.inputs[2].name
        else:
            bias = np.zeros_like(mean)
            bias_name = x.name + "_bias"
        fused_bias = ir.tensor((bias - mean) * scale + beta)

        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=producer.inputs[1].name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=producer.attributes,
        )

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Accept the match only if every folded parameter is a constant initializer."""
        del context  # Unused
        result = MatchResult()
        producer = inbound_out.producer()
        bn_node = batchnorm_out.producer()

        # Producer weights, its optional bias, and all four BatchNorm parameters
        # must be constant initializers that are not graph inputs.
        candidates = [producer.inputs[1], *bn_node.inputs[1:]]
        if len(producer.inputs) > 2:
            candidates.append(producer.inputs[2])

        for value in candidates:
            if not value.is_initializer() or value.const_value is None:
                return result.fail(f"{value.name} is not a constant initializer.")
            if value.is_graph_input():
                return result.fail(f"{value.name} is a graph input.")
        return result
class FuseBatchNormIntoConv(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

    def __init__(self):
        super().__init__("Conv")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # Conv weights put output channels on the first axis.
        return 0

    def pattern(self, op, x):
        conv_out = op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            conv_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

    def __init__(self):
        super().__init__("ConvTranspose")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # ConvTranspose weights put output channels on the second axis.
        return 1

    def pattern(self, op, x):
        conv_transpose_out = op.ConvTranspose(
            x, _allow_other_inputs=True, _outputs=["inbound_out"]
        )
        return op.BatchNormalization(
            conv_transpose_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

    def __init__(self):
        super().__init__("Gemm")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # When transB is set and non-zero, B is transposed, so the output
        # features lie on axis 0; otherwise they lie on axis 1.
        trans_b = attributes.get("transB")
        if trans_b is not None and trans_b.as_int():
            return 0
        return 1

    def pattern(self, op, x):
        gemm_out = op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            gemm_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
# Module-level rule instances, bundled together by fuse_batchnorm_rule_set().
fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()
def fuse_batchnorm_rule_set() -> RewriteRuleSet:
    """Returns a set of rewrite rules that fuse BatchNormalization nodes
    into preceding nodes such as Conv, ConvTranspose, and Gemm.

    Returns:
        RewriteRuleSet
    """
    batchnorm_fusion_rules = [
        fuse_batchnorm_into_conv_rule,
        fuse_batchnorm_into_convtranspose_rule,
        fuse_batchnorm_into_gemm_rule,
    ]
    return RewriteRuleSet(batchnorm_fusion_rules)