File size: 2,663 Bytes
289f097 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import torch
import torch.nn as nn
from collections import OrderedDict
def group_weight(module):
    """Split the parameters of ``module`` into weight-decay / no-decay groups.

    Weights of ``nn.Linear`` and convolution layers (``_ConvNd`` subclasses)
    go into the decay group. Their biases, and the affine weight/bias of
    batch-norm (``_BatchNorm`` subclasses) and ``nn.GroupNorm`` layers, go
    into the no-decay group.

    A parameter shared by several sub-modules (tied weights) is placed only
    once, in the group of the first sub-module that yields it from
    ``module.modules()``; torch optimizers reject a tensor that appears in
    more than one parameter group.

    Args:
        module: the ``nn.Module`` whose parameters are grouped.

    Returns:
        A two-element list usable as the ``params`` argument of a
        ``torch.optim`` optimizer:
        ``[{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}]``.
        The first group carries no ``weight_decay`` key, so it uses the
        optimizer's default.

    Raises:
        ValueError: if some parameter of ``module`` is not covered by the
            layer types above (e.g. ``nn.LayerNorm``, ``nn.Embedding``), since
            it would otherwise be silently left out of the optimizer.
    """
    decay_layer_types = (nn.Linear, nn.modules.conv._ConvNd)
    norm_layer_types = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)

    group_decay = []
    group_no_decay = []
    seen = set()  # ids of parameters already placed in a group

    def _add(group, param):
        # Skip absent parameters (e.g. bias=False, affine=False) and ones
        # already placed via another module (tied weights).
        if param is None or id(param) in seen:
            return
        seen.add(id(param))
        group.append(param)

    for m in module.modules():
        if isinstance(m, decay_layer_types):
            _add(group_decay, m.weight)
            _add(group_no_decay, m.bias)
        elif isinstance(m, norm_layer_types):
            _add(group_no_decay, m.weight)
            _add(group_no_decay, m.bias)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = [name for name, p in module.named_parameters()
               if id(p) not in seen]
    if missing:
        raise ValueError(
            f"group_weight: parameters not assigned to any group "
            f"(unsupported layer types): {missing}")

    return [dict(params=group_decay),
            dict(params=group_no_decay, weight_decay=0.0)]
def adjust_learning_rate(optimizer, args):
    """Set the learning rate of every param group for the current iteration.

    The schedule reads ``cur_iter``, ``warmup_iters``, ``warmup_lr``, ``lr``,
    ``max_iters`` and ``lr_pow`` from ``args``:

    * warm-up (``cur_iter < warmup_iters``): linear ramp from ``warmup_lr``
      towards ``lr``;
    * afterwards: polynomial decay
      ``lr * max(1 - (cur_iter - warmup_iters) / (max_iters - warmup_iters), 0) ** lr_pow``.

    If ``max_iters <= warmup_iters`` there is no decay phase. The schedule is
    then treated as finished (progress fraction 1) rather than dividing by
    zero or by a negative span.

    Side effects: stores the value in ``args.running_lr`` and writes it to
    ``param_group['lr']`` for every group of ``optimizer``.
    """
    if args.cur_iter < args.warmup_iters:
        frac = args.cur_iter / args.warmup_iters
        step = args.lr - args.warmup_lr
        args.running_lr = args.warmup_lr + step * frac
    else:
        decay_iters = args.max_iters - args.warmup_iters
        if decay_iters > 0:
            frac = (float(args.cur_iter) - args.warmup_iters) / decay_iters
        else:
            # Empty (or misconfigured negative) decay phase: schedule is done.
            frac = 1.0
        # Clamp at 0 so iterations past max_iters never produce a negative base.
        scale_running_lr = max(1. - frac, 0.) ** args.lr_pow
        args.running_lr = args.lr * scale_running_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.running_lr
def save_model(net, path, args):
    """Serialize ``net`` together with its run configuration to ``path``.

    The saved object is an ``OrderedDict`` with three entries, in order:

    * ``'args'`` -- ``args.__dict__`` (the run configuration; presumably an
      ``argparse.Namespace`` -- not verified here);
    * ``'kwargs'`` -- the ``backbone`` and ``use_rnn`` attributes of ``net``,
      which ``load_trained_model`` passes back to the model constructor;
    * ``'state_dict'`` -- ``net.state_dict()``.

    Args:
        net: model exposing ``backbone``, ``use_rnn`` and ``state_dict()``.
        path: destination handed unchanged to ``torch.save``.
        args: object whose attribute dictionary is stored.
    """
    checkpoint = OrderedDict()
    checkpoint['args'] = args.__dict__
    checkpoint['kwargs'] = dict(backbone=net.backbone, use_rnn=net.use_rnn)
    checkpoint['state_dict'] = net.state_dict()
    torch.save(checkpoint, path)
def load_trained_model(Net, path):
    """Rebuild a model with ``Net`` and load weights from the checkpoint at ``path``.

    Accepted checkpoint layouts:

    * ``{'kwargs': {...}, 'state_dict': ...}`` -- the layout written by
      ``save_model``; ``Net`` is called with the stored keyword arguments.
    * ``{'backbone': ..., 'state_dict': ...}`` -- a training-script layout;
      ``Net`` is called with that backbone and with ``use_rnn`` taken from
      the checkpoint when present (default ``True``).
    * anything else -- ``Net('resnet50', use_rnn=True)``; weights come from
      the ``'state_dict'`` entry if present, otherwise the checkpoint itself
      is treated as a bare state dict.

    Args:
        Net: model class/factory; assumed to accept ``(backbone, use_rnn=...)``
            as well as the stored kwargs -- confirm against the model class.
        path: file path or file-like object passed to ``torch.load``.

    Returns:
        The constructed ``Net`` instance with the checkpoint weights loaded.
        Checkpoint tensors are mapped to CPU while reading.

    Raises:
        KeyError: if a ``'kwargs'``/``'backbone'`` checkpoint has no
            ``'state_dict'`` entry.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys or
            shape mismatches.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(path, map_location='cpu')

    if 'kwargs' in checkpoint:
        # Layout written by save_model: constructor kwargs stored explicitly.
        net = Net(**checkpoint['kwargs'])
        weights = checkpoint['state_dict']
    elif 'backbone' in checkpoint:
        # Training-script layout: backbone stored at the top level.
        net = Net(checkpoint['backbone'],
                  use_rnn=checkpoint.get('use_rnn', True))
        weights = checkpoint['state_dict']
    else:
        # No architecture metadata: fall back to default constructor values.
        net = Net('resnet50', use_rnn=True)
        # Accept both {'state_dict': ...} wrappers and bare state dicts.
        weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    net.load_state_dict(weights)
    return net
|