| import torch |
| import torch.nn as nn |
| from collections import OrderedDict |
|
|
|
|
def group_weight(module):
    """Split the parameters of ``module`` into weight-decay / no-weight-decay groups.

    Rules, applied to every submodule yielded by ``module.modules()``:
      - ``nn.Linear`` / convolution layers: ``weight`` -> decay, ``bias`` -> no decay.
      - ``_BatchNorm`` / ``nn.GroupNorm``: ``weight`` and ``bias`` -> no decay.
      - Any other direct parameter (e.g. of RNN/LSTM, LayerNorm, Embedding or
        custom modules): tensors with ``ndim <= 1`` -> no decay, others -> decay.

    Each parameter is placed exactly once, even if it is shared (tied) between
    several submodules.

    Args:
        module: the ``nn.Module`` whose parameters are grouped.

    Returns:
        A list of two param-group dicts suitable for a ``torch.optim`` optimizer:
        ``[{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}]``.
    """
    group_decay = []
    group_no_decay = []
    seen = set()  # ids of parameters already assigned to a group

    def _add(group, param):
        # Skip missing (None) params and params already placed via another module.
        if param is not None and id(param) not in seen:
            seen.add(id(param))
            group.append(param)

    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
            _add(group_decay, m.weight)
            _add(group_no_decay, m.bias)
        elif isinstance(m, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
            _add(group_no_decay, m.weight)
            _add(group_no_decay, m.bias)

        # Catch every remaining parameter owned directly by this module, so that
        # modules not listed above no longer trip the coverage check below.
        for p in m.parameters(recurse=False):
            if id(p) in seen:
                continue
            _add(group_no_decay if p.ndim <= 1 else group_decay, p)

    # Internal invariant: every (deduplicated) parameter lands in exactly one group.
    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    return [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
|
|
|
|
def adjust_learning_rate(optimizer, args):
    """Set the learning rate of every param group for iteration ``args.cur_iter``.

    Schedule:
      - ``cur_iter < warmup_iters``: linear ramp from ``args.warmup_lr`` to ``args.lr``.
      - otherwise: polynomial decay ``args.lr * max(1 - frac, 0) ** args.lr_pow``,
        where ``frac`` is the fraction of the ``[warmup_iters, max_iters]`` window
        already elapsed. If that window is empty or negative
        (``max_iters <= warmup_iters``), the decay counts as finished
        (``frac = 1``).

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        args: namespace providing ``cur_iter``, ``warmup_iters``, ``warmup_lr``,
            ``lr``, ``max_iters`` and ``lr_pow``. Its ``running_lr`` attribute is
            overwritten with the computed rate.
    """
    if args.cur_iter < args.warmup_iters:
        frac = float(args.cur_iter) / args.warmup_iters
        step = args.lr - args.warmup_lr
        args.running_lr = args.warmup_lr + step * frac
    else:
        decay_iters = args.max_iters - args.warmup_iters
        if decay_iters > 0:
            frac = (float(args.cur_iter) - args.warmup_iters) / decay_iters
        else:
            # Empty/negative decay window: previously this divided by zero, or
            # produced a negative frac that pushed the lr above args.lr.
            frac = 1.0
        scale_running_lr = max((1. - frac), 0.) ** args.lr_pow
        args.running_lr = args.lr * scale_running_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = args.running_lr
|
|
|
|
def save_model(net, path, args):
    """Write a checkpoint for ``net`` to ``path`` using ``torch.save``.

    The checkpoint is an ``OrderedDict`` with these keys, in order:
      - ``'args'``: ``args.__dict__`` (the run configuration);
      - ``'kwargs'``: ``{'backbone': net.backbone, 'use_rnn': net.use_rnn}``,
        i.e. the constructor arguments needed to rebuild the network;
      - ``'state_dict'``: ``net.state_dict()``.
    """
    checkpoint = OrderedDict()
    checkpoint['args'] = args.__dict__
    checkpoint['kwargs'] = dict(backbone=net.backbone, use_rnn=net.use_rnn)
    checkpoint['state_dict'] = net.state_dict()
    torch.save(checkpoint, path)
|
|
|
|
def load_trained_model(Net, path):
    """Construct a ``Net`` from the checkpoint at ``path`` and load its weights.

    Tensors are mapped to CPU. Constructor arguments are chosen as follows:
      - if the checkpoint has a ``'kwargs'`` entry, call ``Net(**checkpoint['kwargs'])``;
      - otherwise call ``Net(backbone, use_rnn=True)``, where ``backbone`` is the
        checkpoint's ``'backbone'`` entry, or ``'resnet50'`` if there is none.
    Weights always come from ``checkpoint['state_dict']``.

    NOTE(review): ``torch.load`` is pickle-based; only load checkpoints from
    trusted sources.
    """
    checkpoint = torch.load(path, map_location='cpu')

    if 'kwargs' in checkpoint:
        net = Net(**checkpoint['kwargs'])
    else:
        # Presumably older checkpoints that did not store constructor kwargs;
        # use_rnn is hard-coded to True for them.
        net = Net(checkpoint.get('backbone', 'resnet50'), use_rnn=True)

    net.load_state_dict(checkpoint['state_dict'])
    return net
|
|