| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | from collections import defaultdict |
| | import typing as tp |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
    """Collect the dotted names of every non-persistent buffer in ``module``.

    Non-persistent buffers are those registered with ``persistent=False``. The
    returned names use the same dotted form as ``module.named_buffers()``
    (e.g. ``"encoder.block.cache"``), optionally prefixed by ``root``.

    Args:
        module: Module to inspect, together with all of its submodules.
        root: Prefix joined with a ``.`` in front of every returned name.
            The empty string means no prefix.

    Returns:
        A set of dotted buffer names.
    """
    names: set = set()
    # One ``named_modules`` walk already reaches every submodule with its full
    # dotted name. Recursing into each submodule on top of this walk would
    # re-walk every subtree, which is exponential in nesting depth.
    for mod_name, sub_module in module.named_modules(prefix=root):
        prefix = f"{mod_name}." if mod_name else ""
        names.update(f"{prefix}{buff_name}"
                     for buff_name in sub_module._non_persistent_buffers_set)
    return names
| |
|
| |
|
def _get_named_tensors(module: nn.Module):
    """Return ``(name, tensor)`` pairs for the parameters and persistent buffers of ``module``.

    Every entry of ``module.named_parameters()`` comes first. It is followed by
    each entry of ``module.named_buffers()`` whose name is not registered as a
    non-persistent buffer.
    """
    transient = _get_all_non_persistent_buffers_set(module)
    kept_buffers = []
    for buffer_name, tensor in module.named_buffers():
        if buffer_name in transient:
            continue
        kept_buffers.append((buffer_name, tensor))
    return [*module.named_parameters(), *kept_buffers]
| |
|
| |
|
class ModuleDictEMA:
    """Exponential Moving Average (EMA) of the tensors of every module in a ``nn.ModuleDict``.

    For each module, the floating-point parameters and persistent buffers are
    tracked. Non-floating-point tensors are skipped. Averaged copies live in
    ``self.state[module_name][tensor_name]``, and the modules themselves are
    never modified. Call :meth:`step` to fold the modules' current values into
    the average, typically after each optimizer update. Use :meth:`state_dict`
    and :meth:`load_state_dict` for checkpointing.

    Args:
        module_dict: Modules whose tensors are averaged.
        decay: EMA decay factor in the per-step update.
        unbias: If True, each step uses weight ``1 / count``, where ``count``
            is updated as ``count * decay + 1``. The first :meth:`step`
            therefore overwrites the initial copies, and later averages are
            normalized. If False, each step uses the constant weight
            ``1 - decay``.
        device: Device that holds the averaged copies. A falsy value (e.g.
            ``None``) keeps each copy on its source tensor's device.
    """
    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
                 unbias: bool = True,
                 device: tp.Optional[tp.Union[torch.device, str]] = 'cpu'):
        self.decay = decay
        self.module_dict = module_dict
        # module name -> tensor name -> averaged tensor
        self.state: dict = defaultdict(dict)
        # Bias-correction accumulator; only read when unbias=True.
        self.count = 0
        self.device = device
        self.unbias = unbias
        self._init()

    def _tracked_tensors(self):
        """Yield ``(module_name, tensor_name, tensor, target_device)`` for each tracked tensor.

        Only floating-point parameters and persistent buffers are yielded.
        """
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if not val.is_floating_point():
                    continue
                yield module_name, key, val, self.device or val.device

    def _init(self):
        """Create an averaged copy of every tracked tensor not yet present in ``state``."""
        for module_name, key, val, device in self._tracked_tensors():
            if key not in self.state[module_name]:
                # copy=True makes the EMA own its storage even when no device move happens.
                self.state[module_name][key] = val.detach().to(device, copy=True)

    def step(self):
        """Fold the current value of every tracked tensor into its running average.

        Raises:
            KeyError: if a tracked tensor has no averaged copy, e.g. it was
                added to a module after construction and ``_init`` was not re-run.
        """
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        with torch.no_grad():
            for module_name, key, val, device in self._tracked_tensors():
                avg = self.state[module_name][key]
                # avg <- (1 - w) * avg + w * val, done in place.
                avg.mul_(1 - w)
                avg.add_(val.detach().to(device), alpha=w)

    def state_dict(self):
        """Return the EMA state for checkpointing.

        The ``'state'`` entry is the live nested dict, not a copy. Its tensors
        keep changing on later calls to :meth:`step`.
        """
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        """Restore values produced by :meth:`state_dict`.

        Values are copied in place into the existing averaged tensors. Every
        ``(module_name, tensor_name)`` in ``state['state']`` must therefore
        already be tracked, otherwise ``KeyError`` is raised.
        """
        self.count = state['count']
        # Copy without autograd tracking, as nn.Module.load_state_dict does.
        # If an incoming tensor requires grad, a tracked copy_ would give the
        # averaged tensor an autograd history and make it require grad.
        with torch.no_grad():
            for module_name, saved_tensors in state['state'].items():
                for key, val in saved_tensors.items():
                    self.state[module_name][key].copy_(val)
| |
|