Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import OrderedDict | |
| from mmengine.dist import get_dist_info | |
| from mmengine.hooks import Hook | |
| from torch import nn | |
| from mmdet.registry import HOOKS | |
| from mmdet.utils import all_reduce_dict | |
def get_norm_states(module: nn.Module) -> OrderedDict:
    """Collect the state_dict entries of all normalization layers in a module.

    Every submodule (including ``module`` itself) that is an instance of
    ``nn.modules.batchnorm._NormBase`` contributes the entries of its own
    ``state_dict()``. Each key is qualified with the submodule's name relative
    to ``module``, so the result lines up with ``module.state_dict()`` keys
    and can be passed back to ``module.load_state_dict``.

    Args:
        module (nn.Module): The module to scan.

    Returns:
        OrderedDict: Mapping from qualified state name (e.g.
        ``'bn1.running_mean'``) to tensor, in ``named_modules()`` traversal
        order. Empty if ``module`` contains no ``_NormBase`` layer.
    """
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if not isinstance(child, nn.modules.batchnorm._NormBase):
            continue
        # ``named_modules`` reports the root module with an empty name.
        # Joining '' with '.' would yield keys such as '.weight' that do not
        # match ``module.state_dict()`` and would be silently ignored by a
        # non-strict ``load_state_dict``.
        prefix = f'{name}.' if name else ''
        for k, v in child.state_dict().items():
            async_norm_states[prefix + k] = v
    return async_norm_states
@HOOKS.register_module()
class SyncNormHook(Hook):
    """Synchronize Norm states before validation, currently used in YOLOX.

    Before each validation epoch, the states of every ``_NormBase`` layer in
    the model are combined across ranks with
    ``all_reduce_dict(..., op='mean')`` and loaded back into the model.
    Presumably this keeps per-rank normalization statistics from diverging
    between ranks at evaluation time.
    """

    def before_val_epoch(self, runner) -> None:
        """Average normalization-layer states across distributed ranks.

        Returns early (no-op) when the world size is 1 or when the model has
        no ``_NormBase`` layers.

        Args:
            runner: The runner whose ``model`` is updated in place.
        """
        module = runner.model
        _, world_size = get_dist_info()
        if world_size == 1:
            # Single process: there is nothing to synchronize with.
            return
        norm_states = get_norm_states(module)
        if not norm_states:
            return
        # TODO: use `all_reduce_dict` in mmengine
        norm_states = all_reduce_dict(norm_states, op='mean')
        # strict=False because only the norm-layer entries are supplied; all
        # other parameters/buffers of the model are intentionally absent.
        module.load_state_dict(norm_states, strict=False)