| | |
| | |
| |
|
| | """ |
| | @Author : Peike Li |
| | @Contact : peike.li@yahoo.com |
| | @File : schp.py |
| | @Time : 4/8/19 2:11 PM |
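| | @Desc : Helpers for averaging model weights, re-estimating InPlaceABNSync batch-norm statistics and saving checkpoints. |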
| | @License : This source code is licensed under the license found in the |
| | LICENSE file in the root directory of this source tree. |
| | """ |
| |
|
| | import os |
| | import torch |
| | import modules |
| |
|
def moving_average(net1, net2, alpha=1):
    """Blend the parameters of ``net2`` into ``net1`` in place.

    Every parameter of ``net1`` becomes ``(1 - alpha) * p1 + alpha * p2``.
    Parameters are paired positionally through ``zip``, so any surplus
    parameters in the longer model are left untouched. With the default
    ``alpha=1`` this effectively copies ``net2``'s values into ``net1``
    (barring non-finite values in ``net1``, since ``0 * nan`` is still nan).

    Only ``parameters()`` are blended; anything else the modules hold
    (e.g. normalization running statistics) is not touched here.
    """
    keep = 1.0 - alpha
    pairs = zip(net1.parameters(), net2.parameters())
    for target, source in pairs:
        # In-place on .data so autograd does not record the update.
        target.data.mul_(keep).add_(source.data * alpha)
| |
|
| |
|
def _check_bn(module, flag):
    """``model.apply`` visitor: set ``flag[0] = True`` on an InPlaceABNSync layer.

    ``flag`` is a one-element list so the result survives the callback.
    The flag is only ever raised, never reset to False.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    flag[0] = True
| |
|
| |
|
def check_bn(model):
    """Return True if ``model`` or any of its submodules is an InPlaceABNSync layer."""
    # model.modules() yields the model itself followed by every descendant,
    # the same set of modules that model.apply() visits.
    return any(
        isinstance(layer, modules.bn.InPlaceABNSync) for layer in model.modules()
    )
| |
|
| |
|
def reset_bn(module):
    """Reset the running statistics of an InPlaceABNSync layer; other modules are ignored.

    ``running_mean`` is replaced by zeros and ``running_var`` by ones, each
    created with ``*_like`` so shape, dtype and device follow the old tensor.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.ones_like(module.running_var)
| |
|
| |
|
def _get_momenta(module, momenta):
    """``model.apply`` visitor: store an InPlaceABNSync layer's momentum in ``momenta``.

    ``momenta`` is a dict keyed by the module object itself.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    momenta[module] = module.momentum
| |
|
| |
|
def _set_momenta(module, momenta):
    """``model.apply`` visitor: restore an InPlaceABNSync layer's momentum from ``momenta``.

    Raises KeyError if the layer was not recorded in ``momenta`` beforehand.
    """
    if not isinstance(module, modules.bn.InPlaceABNSync):
        return
    saved_momentum = momenta[module]
    module.momentum = saved_momentum
| |
|
| |
|
def bn_re_estimate(loader, model):
    """Re-estimate the running statistics of every InPlaceABNSync layer in ``model``.

    Running statistics are reset (mean -> 0, var -> 1), then every batch of
    ``loader`` is forwarded through ``model`` in training mode. For each batch
    of size ``b``, after ``n`` samples have been seen, the layer momentum is
    set to ``b / (n + b)``. Assuming the layer updates
    ``running = (1 - momentum) * running + momentum * batch_stat`` (the
    PyTorch BN convention), this makes the final estimate a sample-weighted
    average over the whole loader rather than an exponential one.

    Args:
        loader: iterable of batches. A list/tuple batch is assumed to carry
            the input tensor as its first element (e.g. ``(images, labels, meta)``);
            any other batch is passed to the model as-is.
        model: the network whose normalization statistics are recomputed.

    Returns:
        None. Prints a message and returns early if the model has no
        InPlaceABNSync layer.

    The original per-layer momenta and the model's train/eval mode are
    restored afterwards, even if a forward pass raises.
    """
    if not check_bn(model):
        print('No batch norm layer detected')
        return

    was_training = model.training
    model.train()  # batch statistics are only accumulated in training mode
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))

    seen = 0
    try:
        # Gradients are never used here; skipping graph construction saves memory.
        with torch.no_grad():
            for batch in loader:
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                batch_size = images.size(0)
                if batch_size == 0:
                    # An empty batch carries no statistics (and would divide by zero).
                    continue
                momentum = batch_size / (seen + batch_size)
                for module in momenta:
                    module.momentum = momentum
                model(images)
                seen += batch_size
    finally:
        for module, saved_momentum in momenta.items():
            module.momentum = saved_momentum
        model.train(was_training)
| |
|
| |
|
def _atomic_torch_save(obj, path):
    """Serialize ``obj`` to ``path`` so ``path`` is never left truncated or missing.

    The data is written to a sibling ``.tmp`` file first and then moved over
    ``path`` with ``os.replace`` (atomic when both paths share a filesystem).
    A failed save therefore keeps the previous file intact.
    """
    tmp_path = path + '.tmp'
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if save/replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_schp_checkpoint(states, is_best_parsing, output_dir, filename='schp_checkpoint.pth.tar'):
    """Save ``states`` to ``output_dir/filename`` and optionally as the best model.

    Args:
        states: object passed to ``torch.save`` (a dict in practice, since it
            is checked for a ``'state_dict'`` key).
        is_best_parsing: when true and ``states`` contains ``'state_dict'``,
            ``states`` is also saved as ``output_dir/model_parsing_best.pth.tar``.
        output_dir: existing directory to write into.
        filename: name of the regular checkpoint file.

    Existing files are replaced atomically, so an interrupted save no longer
    destroys the previous checkpoint.
    """
    save_path = os.path.join(output_dir, filename)
    _atomic_torch_save(states, save_path)
    if is_best_parsing and 'state_dict' in states:
        best_save_path = os.path.join(output_dir, 'model_parsing_best.pth.tar')
        _atomic_torch_save(states, best_save_path)
| |
|