File size: 7,403 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np


def _fold_batch_norm(weights, bias, gamma, beta, moving_mean, moving_var, epsilon, layer_type):
    """
         Implements the equation for backward BN weights folding.

         For every output channel c the BatchNormalization computes
             y[c] = gamma[c] * (z[c] - moving_mean[c]) / sqrt(moving_var[c] + epsilon) + beta[c]
         where z = W x + b is the output of the preceding layer. Writing s = gamma / sqrt(moving_var + epsilon),
         the folded parameters are
             W'[..., c] = W[..., c] * s[c]
             b'[c]      = beta[c] + (b[c] - moving_mean[c]) * s[c]

         Args:
              weights: original kernel. The output-channel axis is the last axis for 'Conv2D' and 'Dense'.
                  For 'DepthwiseConv2D' the kernel is (kh, kw, in_channels, depth_multiplier).
              bias: original bias, or None when the layer has no bias (treated as zeros)
              gamma: multiplicative trainable parameter of the batch normalisation. Per-channel
              beta: additive trainable parameter of the batch normalisation. Per-channel
              moving_mean: moving average of the layer output, used to centre the distribution
              moving_var: moving variance of the layer output, used to rescale the distribution
              epsilon: a small number to avoid dividing by 0
              layer_type: layer type ('Conv2D', 'DepthwiseConv2D' or 'Dense')
         Returns: (new_weights, new_bias), the folded weights and bias
         Raises: ValueError if layer_type is not one of the supported types
    """
    if layer_type not in ('Conv2D', 'DepthwiseConv2D', 'Dense'):
        raise ValueError("Unsupported layer type for BN folding.")
    if bias is None:
        bias = np.zeros_like(moving_mean)
    scale = gamma / np.sqrt(moving_var + epsilon)
    if layer_type == 'DepthwiseConv2D':
        # Depthwise output channel k * depth_multiplier + q comes from kernel[:, :, k, q], so the
        # per-channel scale is laid out as (in_channels, depth_multiplier). This also works for multiplier > 1.
        kernel_scale = scale.reshape(1, 1, weights.shape[2], weights.shape[3])
    else:
        # Conv2D / Dense: output channels are on the last kernel axis, so plain broadcasting applies.
        kernel_scale = scale
    new_weights = weights * kernel_scale
    new_bias = beta + (bias - moving_mean) * scale
    return new_weights, new_bias


def _bn_parameters(bn_layer):
    """
        Return (gamma, beta, moving_mean, moving_var) of a BatchNormalization layer.

        Keras stores BN weights as [gamma (if scale), beta (if center), moving_mean, moving_variance].
        When scale=False or center=False the missing parameter is replaced by its identity value
        (ones for gamma, zeros for beta), so the folding equations still apply.
    """
    params = list(bn_layer.get_weights())
    moving_var = params.pop()
    moving_mean = params.pop()
    gamma = params.pop(0) if bn_layer.scale else np.ones_like(moving_mean)
    beta = params.pop(0) if bn_layer.center else np.zeros_like(moving_mean)
    return gamma, beta, moving_mean, moving_var


def _has_linear_activation(layer):
    """Return True when `layer` applies no non-linearity, which is required to fold a following BN into it."""
    return layer.get_config().get("activation", "linear") in (None, "linear")


def _sole_bn_consumer(layer, model_output_ids):
    """
        Return the BatchNormalization layer that is the only consumer of `layer`'s output, or None.

        Folding replaces `layer`'s output by the BN output, so it is only safe when that output feeds
        exactly one operation (the BN) and is not itself one of the model outputs.
    """
    outbound = layer._outbound_nodes
    if len(outbound) != 1:
        return None
    consumer = outbound[0].operation
    if not isinstance(consumer, layers.BatchNormalization):
        return None
    for node in layer._inbound_nodes:
        if any(id(t) in model_output_ids for t in node.output_tensors):
            return None
    return consumer


def _map_outputs(layer, new_output, tensor_map):
    """Record `new_output` as the replacement of every output tensor produced by `layer` in the original graph."""
    for node in layer._inbound_nodes:
        for t in node.output_tensors:
            tensor_map[id(t)] = new_output


def _fold_into_layer(layer, bn_layer, inbound):
    """Rebuild `layer` with `bn_layer` folded into its kernel/bias, apply it to `inbound` and return its output."""
    weights = layer.get_weights()
    kernel = weights[0]
    bias = weights[1] if len(weights) > 1 else None
    gamma, beta, moving_mean, moving_var = _bn_parameters(bn_layer)
    # DepthwiseConv2D is tested first: in tf.keras 2 it subclasses Conv2D but has a different kernel layout.
    if isinstance(layer, layers.DepthwiseConv2D):
        layer_type = 'DepthwiseConv2D'
    elif isinstance(layer, layers.Conv2D):
        layer_type = 'Conv2D'
    else:
        layer_type = 'Dense'
    new_kernel, new_bias = _fold_batch_norm(kernel, bias, gamma, beta, moving_mean, moving_var,
                                            bn_layer.epsilon, layer_type)
    config = layer.get_config()
    # The folded layer always needs a bias, even if the original had none.
    config['use_bias'] = True
    new_layer = type(layer).from_config(config)
    x = new_layer(inbound)
    new_layer.set_weights([new_kernel, new_bias])
    return x


def fold_bn(model):
    """
        Search for BN to fold in Backward direction and fold them with _fold_batch_norm function

        A BatchNormalization is folded into the preceding Conv2D / DepthwiseConv2D / Dense layer only when:
          - that layer has a linear activation (BN after a non-linearity cannot be folded), and
          - the layer's output feeds only the BN and is not a model output.
        All other layers are re-created from their config with their weights copied. Non-Layer operations
        are reused as-is.

        NOTE(review): the folding assumes the BN normalizes the channel axis that matches the kernel's
        output-channel axis (e.g. axis=-1 for channels_last); confirm for channels_first models.

        Args:
            model: input keras model (Keras 3 functional model; uses `model.operations`)

        Returns: a new keras model, with BN folded in backward direction
    """
    # Map from original tensor id to the corresponding tensor in the new graph
    tensor_map = {}

    # Create new Input layers
    for input_tensor in model.inputs:
        new_input = layers.Input(shape=input_tensor.shape[1:], name=input_tensor.name.split(":")[0])
        tensor_map[id(input_tensor)] = new_input

    # Names of real Layers; anything else in model.operations is a plain Keras/TF op
    layer_names = {layer.name for layer in model.layers}
    model_output_ids = {id(t) for t in model.outputs}
    # Track which BN layers are folded (to skip them)
    folded_bn_layers = set()

    # model.operations rather than model.layers so that operators are also traversed (in graph order)
    for layer in model.operations:
        if isinstance(layer, layers.InputLayer):
            continue

        # Get input tensors for this layer, translated into the new graph
        inbound_tensors = [tensor_map.get(id(t), t)
                           for node in layer._inbound_nodes
                           for t in node.input_tensors]
        if not inbound_tensors:
            continue
        inbound = inbound_tensors[0] if len(inbound_tensors) == 1 else inbound_tensors

        if layer.name in folded_bn_layers:
            # This is a BN layer that was already folded into its producer, skip it
            continue

        bn_layer = None
        if (isinstance(layer, (layers.Conv2D, layers.DepthwiseConv2D, layers.Dense))
                and _has_linear_activation(layer)):
            bn_layer = _sole_bn_consumer(layer, model_output_ids)

        if bn_layer is not None and bn_layer.name not in folded_bn_layers:
            x = _fold_into_layer(layer, bn_layer, inbound)
            # The BN output is now produced directly by the folded layer
            _map_outputs(bn_layer, x, tensor_map)
            folded_bn_layers.add(bn_layer.name)
        elif layer.name not in layer_names:
            # Keras or TF op that is not a Layer: reuse it. All inbound tensors are passed positionally,
            # so `x` is always this op's own output.
            x = layer(*inbound) if isinstance(inbound, list) else layer(inbound)
            _map_outputs(layer, x, tensor_map)
        else:
            # Just recreate the layer and copy its weights
            new_layer = type(layer).from_config(layer.get_config())
            x = new_layer(inbound)
            layer_weights = layer.get_weights()
            if layer_weights:
                new_layer.set_weights(layer_weights)
            _map_outputs(layer, x, tensor_map)

    # Build new model
    new_outputs = [tensor_map[id(out)] for out in model.outputs]
    new_inputs = [tensor_map[id(inp)] for inp in model.inputs]
    return models.Model(inputs=new_inputs, outputs=new_outputs)