| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.optim import build_optim_wrapper | |
class ExampleModel(nn.Module):
    """Small fixture module for the optimizer-building test in this file.

    It owns a bare ``nn.Parameter``, two 1x1 convolutions (the first one
    without a bias term) and a 2-channel batch-norm layer. ``forward`` is an
    identity and never uses any of them.
    """

    def __init__(self):
        super(ExampleModel, self).__init__()
        # Attribute assignment order fixes the order of ``parameters()``.
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=2)

    def forward(self, x):
        """Return the input unchanged."""
        output = x
        return output
# Optimizer hyper-parameters (learning rate, weight decay, SGD momentum)
# shared by the optimizer config built in this module.
base_lr, base_wd, momentum = 0.01, 1e-4, 0.9
def test_build_optimizer():
    """Build an ``OptimWrapper`` from a config dict and check its optimizer.

    Verifies that ``build_optim_wrapper`` instantiates the requested
    ``torch.optim.SGD`` and passes it the configured hyper-parameters
    (learning rate, weight decay, momentum) together with every parameter
    of the model.
    """
    model = ExampleModel()
    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(
            type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
    optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

    # The optimizer class is resolved from the ``type`` string in the config.
    optimizer = optim_wrapper.optimizer
    assert isinstance(optimizer, torch.optim.SGD)

    # No ``paramwise_cfg`` is given, so a single param group is expected that
    # carries the config's hyper-parameters and covers all model parameters.
    # Checking these catches a wrapper that silently drops config values.
    assert len(optimizer.param_groups) == 1
    param_group = optimizer.param_groups[0]
    assert param_group['lr'] == base_lr
    assert param_group['weight_decay'] == base_wd
    assert param_group['momentum'] == momentum
    assert {id(p) for p in param_group['params']} == {
        id(p) for p in model.parameters()
    }