|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import List, Optional, Tuple

import paddle
import paddle.nn as nn
|
|
|
|
|
|
|
|
def get_bn_running_state_names(
    model: nn.Layer,
    bn_types: Optional[Tuple[type, ...]] = None,
) -> List[str]:
    """Get the full names of batch-norm running statistics in ``model``.

    Every entry yielded by ``model.named_sublayers()`` that is an instance of
    one of ``bn_types`` contributes two names: ``f'{name}._mean'`` (running
    mean) and ``f'{name}._variance'`` (running variance). Names are returned
    in the iteration order of ``named_sublayers()``.

    Args:
        model: Network whose sublayers are scanned.
        bn_types: Layer class (or tuple of classes, anything accepted by
            ``isinstance``) treated as batch-norm layers. Defaults to
            ``(nn.BatchNorm2D, nn.SyncBatchNorm)``, matching the original
            behaviour. Pass a wider tuple, e.g. including ``nn.BatchNorm1D``
            / ``nn.BatchNorm3D``, to collect those layers' statistics too.

    Returns:
        A flat list ``[mean_0, var_0, mean_1, var_1, ...]`` with one
        ``(_mean, _variance)`` pair per matching sublayer.

    Raises:
        AttributeError: If a matching sublayer lacks ``_mean`` or
            ``_variance``.
    """
    # Resolve the default lazily so the signature does not bake in the
    # tuple at import time and callers can pass None explicitly.
    if bn_types is None:
        bn_types = (nn.BatchNorm2D, nn.SyncBatchNorm)

    names = []
    for layer_name, layer in model.named_sublayers():
        if not isinstance(layer, bn_types):
            continue
        # Explicit check (not `assert`) so validation survives `python -O`;
        # this matters once callers can supply their own layer types.
        for attr in ('_mean', '_variance'):
            if not hasattr(layer, attr):
                raise AttributeError(f'{layer} has no attribute {attr}')
        names.extend([f'{layer_name}._mean', f'{layer_name}._variance'])

    return names
|
|
|