File size: 25,392 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import tensorflow as tf
import keras
from keras import layers
from keras.models import Model
import numpy as np
import re
from keras.ops import clip, relu

from .bn_folding import fold_bn
from .network_parsing_utils import (get_outbound_nodes, clone_function, get_output_layers_names, node_type, node_name, 
                                   node_config, node_get_weights, node_set_weights, node_activation, layer_type, 
                                   tensor_inbound_node_name, history_operation_class_name)

# Class names of layers that may sit between a DepthwiseConv2D and a Conv2D without preventing cross-layer
# equalization of the couple (see _is_neutral_layer). Matched against the Keras class name, hence "PReLU" (the
# Keras spelling) and not "PreLU", which matched no layer at all.
CLE_NEUTRAL_LAYERS_NAMES = ["ReLU", "PReLU", "Dropout", "ZeroPadding2D"]
# Upper saturation level of ReLU6; used to derive the per-channel saturation levels after equalization.
RELU6_SAT_UP = 6.0


def _is_neutral_layer(node):
    """
    Tell whether ``node`` is 'neutral' from the cross-layer equalization point of view.

    A node is neutral when its type is listed in CLE_NEUTRAL_LAYERS_NAMES, or when it is an ``Activation`` layer
    whose activation is 'relu' or 'relu6'.

    Args:
        node: the Keras node we want to analyse

    Returns:
        A boolean indicating if the node is considered as 'neutral' for cross-layer equalization.
    """
    kind = node_type(node)
    if kind in CLE_NEUTRAL_LAYERS_NAMES:
        return True
    return kind == "Activation" and node_activation(node) in ('relu', 'relu6')


def _is_relu6(node):
    """
    Tell whether ``node`` is a ReLU6.

    A node is a ReLU6 when it is an ``Activation`` layer with activation 'relu6', or a ``ReLU`` layer whose
    ``max_value`` is exactly RELU6_SAT_UP.

    Args:
        node: the Keras node we want to analyse

    Returns:
        True if the node is a ReLU6, False otherwise.
    """
    kind = node_type(node)
    if kind == "Activation":
        return node_activation(node) == "relu6"
    if kind == "ReLU":
        config = node_config(node)
        max_value = config['max_value'] if 'max_value' in config else None
        # Compare as float: int() would truncate e.g. max_value=6.5 down to 6 and wrongly report a ReLU6.
        return max_value is not None and float(max_value) == RELU6_SAT_UP
    return False


def _bn_parameters(model):
    """
    Collect gamma/beta of every BatchNormalization that immediately follows a DepthwiseConv2D.

    To be called before BN folding; the values are used later on for bias absorption.

    Args:
        model: the Keras model before folding

    Returns:
        A dictionary mapping the DepthwiseConv2D layer name to ``[gamma, beta]`` of the BN that follows it.
        When the BN has no gamma (``scale=False``) or no beta (``center=False``), the missing parameter is
        reported with its neutral value (ones or zeros respectively).
    """
    bn_parameters_dict = {}

    for layer in model.layers:
        if layer_type(layer) != "DepthwiseConv2D":
            continue
        out_nodes, n_out_nodes, out_nodes_type, _ = get_outbound_nodes(layer)
        # controls that DW and BN are sequential otherwise algo undefined
        if n_out_nodes != 1 or out_nodes_type[0] != "BatchNormalization":
            continue

        bn_node = out_nodes[0]
        bn_weights = list(node_get_weights(bn_node))
        bn_config = node_config(bn_node)
        # Keras only creates gamma when scale=True and beta when center=True; moving mean/variance always come
        # after them. Reading fixed positions [0]/[1] would pick the wrong arrays when one of them is absent.
        gamma = bn_weights.pop(0) if bn_config.get('scale', True) else None
        beta = bn_weights.pop(0) if bn_config.get('center', True) else None
        moving_mean = bn_weights[0]
        if gamma is None:
            gamma = np.ones_like(moving_mean)
        if beta is None:
            beta = np.zeros_like(moving_mean)
        # store name previous DW and gamma, beta
        bn_parameters_dict[layer.name] = [gamma, beta]

    return bn_parameters_dict


def _high_bias_absorption(model, coupled_index, bn_params_dict, inv_s, n=3):
    """
    Implement bias absorption as defined in the https://arxiv.org/abs/2201.08442 paper.

    For each equalized DepthwiseConv2D/Conv2D couple whose DW was followed by a BN, a per-channel constant
    c = max(0, beta - n * |gamma|) (BN parameters rescaled by the equalization coefficients) is subtracted from
    the DW bias. Its effect through the Conv2D kernel is added to the Conv2D bias. Channels whose activations are
    likely to reach the ReLU6 saturation level (beta + n * |gamma| >= 6 * inv_s) get c = 0.

    Args:
        model: the Keras model after cross-layer equalization was executed
        coupled_index: list of [dw_index, conv_index] couples on which cross-layer equalization was applied
        bn_params_dict: a dictionary with Batch Norm [gamma, beta] parameters for the original model
        inv_s: per couple, inverse of 's' (equalization coefficient) in reference paper
        n: parameter to approximate Gaussian distribution width

    Returns:
        The model, with the DW and Conv2D biases updated in place through set_weights().
    """
    for k, couple_layer_idx in enumerate(coupled_index):
        dw_layer = model.layers[couple_layer_idx[0]]
        conv_layer = model.layers[couple_layer_idx[1]]
        # handle the case where BN was folded, and we append the tensor name with '_bn_folded' but not in
        # bn_params_dict. Otherwise, the split keeps name_dw unchanged.
        name_dw = dw_layer.name.split('_bn_folded')[0]
        if name_dw not in bn_params_dict:
            continue

        dw_weights = dw_layer.get_weights()
        conv_weights = conv_layer.get_weights()
        # Both layers need a bias: c is removed from the DW bias and compensated in the Conv2D bias.
        if len(dw_weights) < 2 or len(conv_weights) < 2:
            continue

        scale = np.asarray(inv_s[k])
        gamma = bn_params_dict[name_dw][0] * scale
        beta = bn_params_dict[name_dw][1] * scale
        # The width of the pre-activation distribution is |gamma|: BN gamma can be negative, and using it
        # unsigned would put the lower bound above the mean.
        spread = n * np.abs(gamma)
        c = np.maximum(beta - spread, 0.0)

        # there is a potential issue when too many samples of the activations are above saturation point.
        # In this case, the simplifying assumptions taken by the reference paper are no longer valid from a math
        # point of view. In this case, we disable bias absorption by setting 'c' to 0 for the corresponding channels
        sat_level = RELU6_SAT_UP * scale
        c = np.where(beta + spread >= sat_level, 0.0, c)

        w1, b1 = dw_weights[0], dw_weights[1]
        w2, b2 = conv_weights[0], conv_weights[1]
        # Conv2D kernel is (kh, kw, ch_in, ch_out): the constant c (one value per input channel) contributes
        # sum over (kh, kw, ch_in) of w2 * c to each output channel.
        new_b2 = b2 + np.einsum('hwio,i->o', w2, c)

        dw_layer.set_weights([w1, b1 - c] + dw_weights[2:])
        conv_layer.set_weights([w2, new_b2] + conv_weights[2:])

    return model


def _active_number_of_nodes(list_node):
    """
    Remove 'ghost' nodes that are no longer linked in the graph.

    A node named ``X`` is a ghost when a node named ``X + '_bn_folded'`` is also present in ``list_node``.

    Args:
        list_node: list of node at a given place in the network graph

    Returns:
        The sub-list of ``list_node`` (original order preserved) without the ghost nodes. Despite the function
        name, this is a list of nodes, not a count.
    """
    names = {node_name(node) for node in list_node}
    # Computed from a separate set: the previous version removed items from the very list it was iterating over,
    # which skipped names and could leave ghosts in place.
    ghost_names = {name for name in names if name + '_bn_folded' in names}
    return [node for node in list_node if node_name(node) not in ghost_names]


def _couple_names_and_indexes(model):
    """
    Find the DepthwiseConv2D/Conv2D couples that are candidates for cross-layer equalization.

    A couple is a DepthwiseConv2D whose single active consumer is either a Conv2D, or a neutral layer whose
    single active consumer is a Conv2D. When that neutral layer is a ReLU6, its name is also reported.

    Args:
        model: model after batch norm folding

    Returns:
        A tuple ``(names, indexes, relu6_names)``: the list of [dw_name, conv_name] couples, the matching list of
        [dw_index, conv_index] positions in ``model.layers``, and the names of the ReLU6 layers found between
        a DW and its Conv2D.
    """
    coupled_names = []
    relu6_names = []

    for layer in model.layers:
        if layer_type(layer) != "DepthwiseConv2D":
            continue
        successors = _active_number_of_nodes(get_outbound_nodes(layer)[0])
        # DW must have exactly one consumer, otherwise equalization is anyway not specified
        if len(successors) != 1:
            continue
        follower = successors[0]
        if node_type(follower) == "Conv2D":
            coupled_names.append([layer.name, node_name(follower)])
            continue
        if not _is_neutral_layer(follower):
            continue
        next_nodes = _active_number_of_nodes(get_outbound_nodes(follower)[0])
        # the neutral layer must itself feed exactly one Conv2D
        if len(next_nodes) == 1 and node_type(next_nodes[0]) == "Conv2D":
            coupled_names.append([layer.name, node_name(next_nodes[0])])
            if _is_relu6(follower):
                relu6_names.append(node_name(follower))

    position_of = {}
    for idx, lyr in enumerate(model.layers):
        position_of.setdefault(lyr.name, idx)
    coupled_index = [[position_of[dw_name], position_of[conv_name]] for dw_name, conv_name in coupled_names]

    return coupled_names, coupled_index, relu6_names


def choose_tensors_when_multiple_outputs(layer_input_tensor, layer_input_signature):
    """
    Select, among candidate input tensors, the ones a layer actually consumes.

    Non-tuple entries of ``layer_input_tensor`` are kept as they are. An entry that is a tuple (several
    outputs of a single producer) contributes only those members whose ``name`` appears among the names of
    ``layer_input_signature``.

    Args:
        layer_input_tensor: list of candidate input tensors, possibly containing tuples of tensors
        layer_input_signature: the layer's recorded input, either a list of tensors (entries without a
            ``name`` attribute are ignored) or a single tensor

    Returns:
        The list of selected tensors, in order.
    """
    if type(layer_input_signature) is list:
        wanted_names = [sig.name for sig in layer_input_signature if hasattr(sig, 'name')]
    else:
        wanted_names = [layer_input_signature.name]

    selected = []
    for candidate in layer_input_tensor:
        if type(candidate) is not tuple:
            selected.append(candidate)
            continue
        selected.extend(member for member in candidate if member.name in wanted_names)
    return selected


def reorder_multiple_inputs_tensors(layer=None, tensor_name_list=None):
    """
    Return the input names of a multi-input layer in the order of its actual graph connections.

    The inbound name lists built by the caller are unordered, while order matters for some layers (e.g.
    concatenation). The order is taken from ``layer.input``. When any of those input tensors historically came
    from a BatchNormalization (which folding removes), the caller's ``tensor_name_list`` is returned unchanged.

    Args:
        layer: the layer under consideration, for which we need to order the list of input tensors
        tensor_name_list: list of layer input tensor names we may want to re-order

    Returns:
        A list of input tensor names for the layer.
    """
    inputs = layer.input
    producer_classes = [history_operation_class_name(tensor) for tensor in inputs]
    if 'BatchNormalization' in producer_classes:
        return tensor_name_list
    return [tensor_inbound_node_name(tensor) for tensor in inputs]


def insert_layer_in_graph(model, layer_list, insert_layer, inv_scale, insert_layer_name=None, position='replace'):
    """
    Rebuild ``model`` with the layers named in ``layer_list`` replaced by (or combined with) ``insert_layer``.

    Args:
        model: keras model after cross-layer equalization and bias absorption
        layer_list: list of layer names we want to replace in the graph
        insert_layer: what to insert at each matching layer. One of:
            - a list of layers, consumed in order (one per matching layer);
            - a plain function, called as ``insert_layer(t=x, invs=inv_scale[count])`` (adaptive clipping);
            - a falsy value, in which case nothing is inserted.
        inv_scale: per matching layer, the inverse of 's' (equalization coefficients) as described in
            https://arxiv.org/abs/2201.08442 paper; passed to a function ``insert_layer``
        insert_layer_name: name of the layer we insert. Not used at the moment
        position: 'replace' (drop the matching layer), 'after' (matching layer, then inserted one) or
            'before' (inserted layer, then matching one). Always 'replace' for cross-layer equalization

    Returns:
        A new keras Model built on ``model.input``, or ``model`` itself when ``layer_list`` is empty.

    Raises:
        ValueError: if ``position`` is invalid (detected when the first matching layer is reached).
    """
    # early exit
    if not layer_list:
        return model
    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer. We parse the network using model.operations rather than model.layers that
    # contains only objects of class keras.layers and no longer operators.
    for layer in model.operations:
        for name in get_output_layers_names(layer):
            inbound = network_dict['input_layers_of'].setdefault(name, [])
            if layer.name not in inbound:
                inbound.append(layer.name)

    # Multi-input layers: make the order of their inbound names match the actual graph connections
    for layer in model.operations[1:]:
        in_tensor_list = network_dict['input_layers_of'][layer.name]
        if len(in_tensor_list) > 1:
            network_dict['input_layers_of'][layer.name] = reorder_multiple_inputs_tensors(
                layer=layer, tensor_name_list=in_tensor_list)

    # Set the output tensor of the input layer
    if len(model.input) == 1:
        network_dict['new_output_tensor_of'].update({model.layers[0].name: model.input[0]})
    else:
        network_dict['new_output_tensor_of'].update({model.layers[0].name: model.input})
    # Iterate over all layers after the input
    model_outputs = []
    # number of matching layers handled so far; indexes insert_layer (when a list) and inv_scale
    count = 0

    # actual layer name list
    layer_name_list = [layer.name for layer in model.layers]

    # For graph modifications we again parse the network using model.operations rather than model.layers that
    # contains only objects of class keras.layers and no longer operators.
    for layer in model.operations[1:]:
        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                       for layer_aux in network_dict['input_layers_of'][layer.name]]
        layer_input = choose_tensors_when_multiple_outputs(layer_input, layer.input)

        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Insert layer if name matches
        if layer.name in layer_list:
            if position in ('replace', 'before'):
                # 'before' must also start from this layer's input: the inserted layer is applied first and the
                # matching layer on top of it below. Leaving x unset would reuse the previous layer's output.
                x = layer_input
            elif position == 'after':
                x = layer(layer_input)
            else:
                raise ValueError('position must be: before, after or replace')

            if insert_layer:
                if type(insert_layer) is list:
                    new_layer = insert_layer[count]
                    x = new_layer(x)
                elif insert_layer.__class__.__name__ == 'ReLU':
                    # NOTE(review): insert_layer() is called without arguments, which does not fit a ReLU
                    # instance; a ReLU class would not reach this branch either (its class name is 'type').
                    # TODO confirm the intended use.
                    new_layer = insert_layer()
                    new_layer._name = '{}_{}'.format(layer.name, 'modified_to_relu')
                    x = new_layer(x)
                elif (insert_layer.__class__.__name__ == 'function' or
                      insert_layer.__class__.__name__ == 'cython_function_or_method'):
                    # adaptive clip
                    x = insert_layer(t=x, invs=inv_scale[count])
                else:
                    pass
                count = count + 1

            if position == 'before':
                x = layer(x)
        else:
            if layer.__class__.__name__ == 'TFOpLambda' or layer.__class__.__name__ == 'SlicingOpLambda':
                # NOTE(review): x is not updated here, so the previous x is recorded as this layer's output.
                print("Lamdba layer detected")
            elif layer.name not in layer_name_list:
                # Keras or TF ops and not type Layers: pass multiple inputs as positional arguments
                if isinstance(layer_input, list):
                    x = layer(*layer_input)
                else:
                    x = layer(layer_input)
            else:
                x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is output in initial model at origin, if layer_name
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.input, outputs=model_outputs)


def _cross_layer_equalisation(model, coupled_index):
    """

        Returns a model where couple layers weights are equalized as described in https://arxiv.org/abs/2201.08442 paper



        Args:

            model: keras model after folding

            coupled_index: index of all the couples DW/Conv2d eligible to equalisation



        Returns: a model with weights and bias updated by cross-layer equalization, and the list of inverse 

        equalisation coefficients.

    """

    eps = 0.0
    list_inv_s = []

    for couple_layer_idx in coupled_index:
        w1 = model.layers[couple_layer_idx[0]].get_weights()[0]
        b1 = model.layers[couple_layer_idx[0]].get_weights()[1]
        # have ch_out first
        w1_tr = np.transpose(w1, (2, 0, 1, 3))

        w2 = model.layers[couple_layer_idx[1]].get_weights()[0]
        b2 = model.layers[couple_layer_idx[1]].get_weights()[1]
        # have ch_in first
        w2_tr = np.transpose(w2, (2, 0, 1, 3))

        # vector s calculation
        r1 = [np.max(e) - np.min(e) for e in w1_tr]
        r2 = [np.max(e) - np.min(e) for e in w2_tr]
        s = [1/(r2[k] + eps) * np.sqrt(r1[k] * r2[k]) for k in range(len(r1))]

        # Treat the corner case where s(k) == 0 in this case it would be impossible to calculate 1/s(k)
        # In case r1(k) was null we can set s(k) to 1 because there is no need in this case to scale down this channel
        # weights, since in any case they are null
        for idx, e in enumerate(s):
            if e == 0 and r1[idx] == 0:
                s[idx] = 1

        inv_s = [1/(e + eps) for e in s]
        list_inv_s.append(inv_s)

        new_w1_tr = [inv_s[k]*channel for k, channel in enumerate(w1_tr)]
        new_w1 = np.array(np.transpose(new_w1_tr, (1, 2, 0, 3)))
        new_b1 = inv_s * b1

        new_w2_tr = [s[k]*channel for k, channel in enumerate(w2_tr)]
        new_w2 = np.array(np.transpose(new_w2_tr, (1, 2, 0, 3)))

        model.layers[couple_layer_idx[0]].set_weights([new_w1, new_b1])
        model.layers[couple_layer_idx[1]].set_weights([new_w2, b2])

    return model, list_inv_s


def _zero_irrelevant_channels(model, min_weights_th, ct_value=0.0):
    """

        Returns a model with weights arbitrarily set to constant value typically 0, if all weights corresponding to a

        given output channel are below 'min_weight_th' in absolute value. Restricted to Conv2d and DW.

        This helps reducing possible bias saturation issue at quantization, when weights channel scale is very small



        Args:

            model: keras model after batch normalisation folding

            min_weights_th: arbitrary threshold under which we consider current weights to be replaced by 'ct_value'

            ct_value: constant value set to the weights when they are < min_weights_th for a given channel. For

            this application ct_value is always set to 0.0



        Returns: the keras model with weights updated



    """

    for layer in model.layers:

        if layer.__class__.__name__ == 'Functional':
            _zero_irrelevant_channels(layer, min_weights_th)
        if layer.__class__.__name__ in ("Conv2D", "DepthwiseConv2D"):
            # weights
            bias_exist = len(layer.get_weights()) == 2
            if bias_exist:
                w = layer.get_weights()[0]
                b = layer.get_weights()[1]
            else:
                w = layer.get_weights()[0]
            if layer.__class__.__name__ == "DepthwiseConv2D":
                # have ch_out first
                w = np.transpose(w, (2, 0, 1, 3))
                for i, we in enumerate(w):
                    if np.abs(np.min(we)) < min_weights_th and np.abs(np.max(we)) < min_weights_th:
                        w[i] = ct_value * np.ones((w.shape[1], w.shape[2], w.shape[3]))
                w = np.transpose(w, (1, 2, 0, 3))
                if bias_exist:
                    layer.set_weights([w, b])
                else:
                    layer.set_weights([w])
            elif layer.__class__.__name__ == "Conv2D":
                # have ch_out first
                w = np.transpose(w, (3, 0, 1, 2))
                for i, we in enumerate(w):
                    if np.abs(np.min(we)) < min_weights_th and np.abs(np.max(we)) < min_weights_th:
                        w[i] = ct_value * np.ones((w.shape[1], w.shape[2], w.shape[3]))
                w = np.transpose(w, (1, 2, 3, 0))
                if bias_exist:
                    layer.set_weights([w, b])
                else:
                    layer.set_weights([w])

    return model


@keras.saving.register_keras_serializable()
class STCustomClip(keras.layers.Layer):
    """
    Element-wise clip of the input between ``min_vector`` and ``max_vector``.

    The bounds are kept as given (the caller passes plain Python lists) and are returned by get_config()
    so the layer can be re-created when a saved model is loaded.
    """

    def __init__(self, name=None, min_vector=None, max_vector=None, **kwargs):
        # **kwargs is required because super().get_config() adds entries to the config that from_config()
        # passes back to the constructor.
        super().__init__(name=name, **kwargs)
        self.min_vector, self.max_vector = min_vector, max_vector

    def call(self, inputs):
        lower, upper = self.min_vector, self.max_vector
        return keras.ops.clip(x=inputs, x_min=lower, x_max=upper)

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"min_vector": self.min_vector, "max_vector": self.max_vector})
        return base_config


def _adaptive_clip_per_channel(t=None, invs=None):
    """
    Apply to tensor ``t`` a per-channel clip whose upper levels are given through 'invs' (replacement for ReLU6).

    After equalization, channel k of the tensor that fed the ReLU6 is scaled by invs[k]. Its saturation level
    therefore becomes RELU6_SAT_UP * invs[k].

    Args:
        t: a Keras tensor input of the adaptive clip per channel layer; channels are on its last axis
        invs: list of equalisation coefficients as described in https://arxiv.org/abs/2201.08442 reference
            paper, one per channel of ``t``

    Returns:
        The Keras tensor produced by a ReLU followed by an STCustomClip layer. The clip limits channel k to
        [0, RELU6_SAT_UP * invs[k]], with the upper levels rounded to multiples of (max - min) / 65535 when
        the levels differ.
    """
    nb_ch_out = int(t.shape[-1])
    ch_sat_level = np.asarray([RELU6_SAT_UP * k for k in invs])
    level_span = np.max(ch_sat_level) - np.min(ch_sat_level)
    # When all channels share the same level (e.g. a single channel) the span is 0: there is nothing to round,
    # and dividing by a zero step would turn every clip level into NaN.
    if level_span > 0:
        scale = level_span / 65535
        ch_sat_level = np.round(ch_sat_level / scale) * scale

    # although not useful from a math point of view since the following clip has clip_min == 0, the addition of this
    # relu before the clip will make the interpreter understand it needs to fuse it with previous layer which helps
    # reducing the dynamic range of the layer output and thus to find a smaller scale and eventually reduce the
    # quantization noise.
    name_layer = 'ReLU_' + t.name
    custom_activ = layers.ReLU(name=name_layer)(t)

    name_layer = 'ST_Custom_Clip_' + t.name
    # important to cast to lists otherwise issue at model loading because from_config() expects basic python types and
    # not np types
    custom_activ = STCustomClip(name=name_layer,
                                min_vector=np.zeros(nb_ch_out).tolist(),
                                max_vector=ch_sat_level.tolist())(custom_activ)

    return custom_activ


def model_formatting_ptq_per_tensor(model_origin):
    """
    Run the whole PTQ optimization chain on a Keras model.

    Steps, in order:
        - batch norm folding
        - zeroing irrelevant (too weak) channels
        - cross layer equalisation (CLE)
        - bias absorption
        - insertion of the adaptive clip layers in place of the ReLU6 layers between equalized couples

    Args:
        model_origin: the original Keras model

    Returns:
        A Keras model optimized for subsequent per-tensor quantization
    """
    # BN gamma/beta must be captured before folding: they are needed later for bias absorption
    bn_params_dict = _bn_parameters(model_origin)

    # fold BN, then zero the weak channels to avoid bias saturation at quantization
    model_folded = _zero_irrelevant_channels(fold_bn(model_origin), min_weights_th=1e-10)

    # couples eligible for equalization (indexes only) and the ReLU6 layers to be replaced
    _, layer_coupled_index, relu6_names = _couple_names_and_indexes(model_folded)

    # reference-paper cross-layer equalization, then the optional bias absorption
    model_cle, list_inv_s = _cross_layer_equalisation(model=model_folded, coupled_index=layer_coupled_index)
    model_cle = _high_bias_absorption(model=model_cle, coupled_index=layer_coupled_index, inv_s=list_inv_s,
                                      bn_params_dict=bn_params_dict, n=3)

    # adaptive per-channel clipping replaces the ReLU6 layers found between equalized couples
    return insert_layer_in_graph(model=model_cle, layer_list=relu6_names,
                                 insert_layer=_adaptive_clip_per_channel, inv_scale=list_inv_s,
                                 insert_layer_name=None,
                                 position='replace')