import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np


def _fold_batch_norm(weights, bias, gamma, beta, moving_mean, moving_var, epsilon, layer_type):
    """
    Implements the backward batch-normalisation folding equations.

    Folding replaces the sequence y = BN(W.x + b) by a single affine operation
    y = W'.x + b' where:

        W' = W * gamma / sqrt(moving_var + epsilon)
        b' = beta + (b - moving_mean) * gamma / sqrt(moving_var + epsilon)

    Args:
        weights: original layer weights
        bias: original layer bias (None if the layer has no bias)
        gamma: multiplicative trainable parameter of the batch normalisation, per channel
        beta: additive trainable parameter of the batch normalisation, per channel
        moving_mean: moving average of the layer output, used to centre the distribution of the samples
        moving_var: moving variance of the layer output, used to scale the distribution of the samples
        epsilon: a small constant to avoid division by zero
        layer_type: layer type ('Dense', 'Conv2D' or 'DepthwiseConv2D')

    Returns:
        The folded weights and bias.
    """
    if bias is None:
        bias = np.zeros_like(moving_mean)
    std = np.sqrt(moving_var + epsilon)
    if layer_type == 'Conv2D':
        # gamma / std broadcasts over the last (output-channel) axis.
        new_weights = weights * (gamma / std)
        new_bias = beta + (bias - moving_mean) * (gamma / std)
    elif layer_type == 'DepthwiseConv2D':
        # Depthwise kernels have shape (kh, kw, in_channels, depth_multiplier);
        # align the per-channel scale with the input-channel axis
        # (this assumes depth_multiplier == 1).
        gamma_std = (gamma / std).reshape(1, 1, -1, 1)
        new_weights = weights * gamma_std
        new_bias = beta + (bias - moving_mean) * (gamma / std)
    elif layer_type == 'Dense':
        new_weights = weights * (gamma / std)
        new_bias = beta + (bias - moving_mean) * (gamma / std)
    else:
        raise ValueError("Unsupported layer type for BN folding.")
    return new_weights, new_bias
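

# A minimal sanity check of the folding equations (illustrative sketch, not
# part of the module API; the helper name and the shapes below are made up):
# for a Dense layer, BN(x @ W + b) computed with the moving statistics should
# match x @ W_folded + b_folded.
def _check_dense_folding_example():
    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 3))
    W = rng.normal(size=(3, 5))
    b = rng.normal(size=(5,))
    gamma = rng.normal(size=(5,))
    beta = rng.normal(size=(5,))
    moving_mean = rng.normal(size=(5,))
    moving_var = rng.uniform(0.5, 1.5, size=(5,))
    epsilon = 1e-3
    # Batch normalisation applied to the Dense output, using moving statistics.
    bn_out = gamma * ((x @ W + b) - moving_mean) / np.sqrt(moving_var + epsilon) + beta
    W_folded, b_folded = _fold_batch_norm(W, b, gamma, beta, moving_mean, moving_var, epsilon, 'Dense')
    assert np.allclose(bn_out, x @ W_folded + b_folded)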


def fold_bn(model):
    """
    Searches the model for BatchNormalization layers that can be folded in the
    backward direction and folds them into the preceding Conv2D,
    DepthwiseConv2D or Dense layer using _fold_batch_norm.

    Args:
        model: input Keras model

    Returns:
        A new Keras model with the batch normalisations folded backward.
    """
    # References to the layers of the rebuilt model.
    new_layers = {}
    # Tensors of the rebuilt graph, keyed by the id of the original tensors.
    tensor_map = {}

    # Recreate the model inputs.
    for input_tensor in model.inputs:
        new_input = layers.Input(shape=input_tensor.shape[1:], name=input_tensor.name.split(":")[0])
        tensor_map[id(input_tensor)] = new_input

    layer_name_list = [layer.name for layer in model.layers]

    # Map tensor ids to the name of the operation that produced them.
    tensor_id_to_layer_name = {}
    for layer in model.operations:
        for node in layer._inbound_nodes:
            for t in node.output_tensors:
                tensor_id_to_layer_name[id(t)] = layer.name

    # Names of the BatchNormalization layers that have already been folded.
    folded_bn_layers = set()

    # Walk the operations in topological order and rebuild the graph,
    # folding batch normalisations along the way.
    for layer in model.operations:
        if isinstance(layer, layers.InputLayer):
            continue

        # Collect the new-graph tensors feeding this layer.
        inbound_tensors = []
        for node in layer._inbound_nodes:
            for t in node.input_tensors:
                inbound_tensors.append(tensor_map.get(id(t), t))
        if not inbound_tensors:
            continue
        inbound = inbound_tensors[0] if len(inbound_tensors) == 1 else inbound_tensors

        is_foldable = isinstance(layer, (layers.Conv2D, layers.DepthwiseConv2D, layers.Dense))

        # Look for a BatchNormalization layer consuming this layer's output.
        # Note: folding assumes the batch normalisation is the only consumer
        # of that output.
        next_bn_layer = None
        for out_node in layer._outbound_nodes:
            if isinstance(out_node.operation, layers.BatchNormalization):
                next_bn_layer = out_node.operation
                break

        if is_foldable and next_bn_layer and next_bn_layer.name not in folded_bn_layers:
            # Fold the following BatchNormalization into this layer's weights.
            weights = layer.get_weights()
            W = weights[0]
            b = weights[1] if len(weights) > 1 else None
            gamma, beta, moving_mean, moving_var = next_bn_layer.get_weights()
            epsilon = next_bn_layer.epsilon
            # Check DepthwiseConv2D before Conv2D: some Keras versions
            # implement DepthwiseConv2D as a Conv2D subclass.
            if isinstance(layer, layers.DepthwiseConv2D):
                new_W, new_b = _fold_batch_norm(W, b, gamma, beta, moving_mean, moving_var, epsilon, 'DepthwiseConv2D')
            elif isinstance(layer, layers.Conv2D):
                W_shape = W.shape
                W = W.reshape(-1, W_shape[-1])
                new_W, new_b = _fold_batch_norm(W, b, gamma, beta, moving_mean, moving_var, epsilon, 'Conv2D')
                new_W = new_W.reshape(W_shape)
            elif isinstance(layer, layers.Dense):
                new_W, new_b = _fold_batch_norm(W, b, gamma, beta, moving_mean, moving_var, epsilon, 'Dense')
            else:
                raise ValueError("Unsupported layer type for BN folding.")

            # Rebuild the layer with a bias (needed to absorb the BN offset),
            # call it on the new inputs, then load the folded weights.
            config = layer.get_config()
            config['use_bias'] = True
            new_layer = type(layer).from_config(config)
            x = new_layer(inbound)
            new_layer.set_weights([new_W, new_b])
            new_layers[layer.name] = new_layer

            # The BN outputs are now produced directly by the folded layer.
            for node in next_bn_layer._inbound_nodes:
                for t in node.output_tensors:
                    tensor_map[id(t)] = x
            folded_bn_layers.add(next_bn_layer.name)
        elif layer.name in folded_bn_layers:
            # BatchNormalization already folded into the previous layer: skip it.
            continue
        elif layer.name not in layer_name_list:
            # Operation that is not registered as a model layer (e.g. a wrapped
            # op): it cannot be rebuilt from a config, so reuse it directly.
            if isinstance(inbound, list):
                if len(inbound) == 2:
                    x = layer(inbound[0], inbound[1])
                else:
                    x = layer(inbound)
            else:
                x = layer(inbound)
            new_layers[layer.name] = layer
            for node in layer._inbound_nodes:
                for t in node.output_tensors:
                    tensor_map[id(t)] = x
        else:
            # Regular, non-folded layer: clone it from its config, call it on
            # the new inputs and copy its weights.
            config = layer.get_config()
            new_layer = type(layer).from_config(config)
            x = new_layer(inbound)
            if layer.get_weights():
                new_layer.set_weights(layer.get_weights())
            new_layers[layer.name] = new_layer
            for node in layer._inbound_nodes:
                for t in node.output_tensors:
                    tensor_map[id(t)] = x

    # Assemble the new functional model from the remapped inputs and outputs.
    new_outputs = [tensor_map[id(out)] for out in model.outputs]
    new_inputs = [tensor_map[id(inp)] for inp in model.inputs]
    new_model = models.Model(inputs=new_inputs, outputs=new_outputs)
    return new_model
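

# Illustrative usage sketch (an assumption about how the module is meant to be
# driven, not part of it): build a tiny Conv2D -> BatchNormalization -> ReLU
# model, fold the BN backward into the convolution, and compare inference
# outputs. All layer choices and shapes here are made up for the example.
if __name__ == "__main__":
    inputs = layers.Input(shape=(8, 8, 3))
    x = layers.Conv2D(4, 3, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    outputs = layers.ReLU()(x)
    model = models.Model(inputs, outputs)

    folded_model = fold_bn(model)

    data = np.random.rand(2, 8, 8, 3).astype("float32")
    reference = model.predict(data, verbose=0)
    folded = folded_model.predict(data, verbose=0)
    print("max abs difference after folding:", np.abs(reference - folded).max())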