# DesignIA / horizonnet / misc / utils.py
# (uploaded by agerhund, "Upload 43 files", commit 289f097 verified)
import torch
import torch.nn as nn
from collections import OrderedDict
def group_weight(module):
    """Partition a module's parameters into weight-decay / no-decay groups.

    Linear and convolution weights go into the decay group; every bias and
    every normalization affine parameter (BatchNorm, GroupNorm) is exempted
    from decay, which is the usual practice for these layer types.

    Returns a two-element list of optimizer param-group dicts; the second
    group carries ``weight_decay=0``.
    """
    decay, no_decay = [], []
    for sub in module.modules():
        if isinstance(sub, (nn.Linear, nn.modules.conv._ConvNd)):
            decay.append(sub.weight)
            if sub.bias is not None:
                no_decay.append(sub.bias)
        elif isinstance(sub, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
            if sub.weight is not None:
                no_decay.append(sub.weight)
            if sub.bias is not None:
                no_decay.append(sub.bias)
    # Every parameter of the module must land in exactly one of the buckets.
    assert len(list(module.parameters())) == len(decay) + len(no_decay)
    return [dict(params=decay), dict(params=no_decay, weight_decay=.0)]
def adjust_learning_rate(optimizer, args):
    """Update every param group's LR according to the current iteration.

    Schedule: linear warmup from ``args.warmup_lr`` to ``args.lr`` over
    ``args.warmup_iters`` iterations, followed by polynomial decay with
    exponent ``args.lr_pow`` down to zero at ``args.max_iters``.

    Side effects: writes the computed rate to ``args.running_lr`` and to
    each ``optimizer.param_groups[i]['lr']``.
    """
    if args.cur_iter < args.warmup_iters:
        # Linear ramp from warmup_lr up to the base lr.
        progress = args.cur_iter / args.warmup_iters
        args.running_lr = args.warmup_lr + (args.lr - args.warmup_lr) * progress
    else:
        # Polynomial ("poly") decay over the post-warmup iterations.
        denom = args.max_iters - args.warmup_iters
        progress = (float(args.cur_iter) - args.warmup_iters) / denom
        args.running_lr = args.lr * max(1. - progress, 0.) ** args.lr_pow
    for group in optimizer.param_groups:
        group['lr'] = args.running_lr
def save_model(net, path, args):
    """Write a checkpoint containing args, constructor kwargs, and weights.

    The checkpoint stores ``args.__dict__`` under ``'args'``, the network's
    ``backbone`` / ``use_rnn`` attributes under ``'kwargs'`` (so the model
    can be re-instantiated later), and ``net.state_dict()`` under
    ``'state_dict'``.
    """
    checkpoint = OrderedDict()
    checkpoint['args'] = args.__dict__
    checkpoint['kwargs'] = {
        'backbone': net.backbone,
        'use_rnn': net.use_rnn,
    }
    checkpoint['state_dict'] = net.state_dict()
    torch.save(checkpoint, path)
def load_trained_model(Net, path):
    """Build a ``Net`` instance from a checkpoint file and load its weights.

    Supports several checkpoint layouts for backward compatibility:
    a ``'kwargs'`` dict of constructor arguments (legacy format), a
    top-level ``'backbone'`` entry (training-time format), or neither,
    in which case a default resnet50 + RNN configuration is assumed.
    All formats are expected to carry the weights under ``'state_dict'``.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(path, map_location='cpu')
    if 'kwargs' in checkpoint:
        # Legacy format: constructor kwargs stored alongside the weights.
        net = Net(**checkpoint['kwargs'])
    elif 'backbone' in checkpoint:
        # Training format: backbone name stored at the checkpoint top level.
        net = Net(checkpoint.get('backbone', 'resnet50'), use_rnn=True)
    else:
        # Fallback: assume the default architecture.
        net = Net('resnet50', use_rnn=True)
    net.load_state_dict(checkpoint['state_dict'])
    return net