import torch
import torch.nn as nn
from collections import OrderedDict

# Layers whose weight receives weight decay while their bias does not.
_DECAY_LAYERS = (nn.Linear, nn.modules.conv._ConvNd)
# Normalization layers: neither their weight nor their bias receives decay.
_NORM_LAYERS = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)


def group_weight(module):
    """Split the parameters of ``module`` into two optimizer parameter groups.

    Weights of linear and convolution layers go into the first group (which
    inherits the optimizer's default weight decay).  Their biases, and the
    weight/bias of batch-norm and group-norm layers, go into the second group,
    which sets ``weight_decay`` to ``0``.

    Args:
        module: The ``nn.Module`` whose sub-modules are inspected.

    Returns:
        A two-element list of dicts, ``[{'params': decay}, {'params':
        no_decay, 'weight_decay': 0.0}]``.

    Raises:
        AssertionError: If the number of grouped parameters differs from the
            number of parameters of ``module``, i.e. some parameter belongs to
            a layer type that is not handled here.
    """
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, _DECAY_LAYERS):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, _NORM_LAYERS):
            # Either attribute may be None, so skip the missing ones.
            group_no_decay.extend(
                p for p in (m.weight, m.bias) if p is not None
            )

    n_params = len(list(module.parameters()))
    n_grouped = len(group_decay) + len(group_no_decay)
    if n_params != n_grouped:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers observe the same exception type.
        grouped_ids = {id(p) for p in group_decay + group_no_decay}
        ungrouped = [
            name
            for name, p in module.named_parameters()
            if id(p) not in grouped_ids
        ]
        raise AssertionError(
            f"group_weight grouped {n_grouped} of {n_params} parameters; "
            f"ungrouped parameters: {ungrouped}"
        )
    return [
        dict(params=group_decay),
        dict(params=group_no_decay, weight_decay=.0),
    ]


def adjust_learning_rate(optimizer, args):
    """Set the learning rate for the current iteration on ``optimizer``.

    While ``args.cur_iter < args.warmup_iters`` the rate is interpolated
    linearly from ``args.warmup_lr`` to ``args.lr``.  Afterwards it follows
    ``args.lr * max(1 - frac, 0) ** args.lr_pow`` where ``frac`` is the
    fraction of the post-warmup iterations already done.

    The result is stored in ``args.running_lr`` and written to the ``'lr'``
    entry of every group in ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        args: Namespace providing ``cur_iter``, ``warmup_iters``,
            ``warmup_lr``, ``lr``, ``max_iters`` and ``lr_pow``.
    """
    if args.cur_iter < args.warmup_iters:
        frac = args.cur_iter / args.warmup_iters
        step = args.lr - args.warmup_lr
        args.running_lr = args.warmup_lr + step * frac
    else:
        # A zero-length decay window (max_iters == warmup_iters) used to raise
        # ZeroDivisionError; treat it as a window of one iteration so the rate
        # is args.lr at the boundary and fully decayed right after it.
        decay_iters = (args.max_iters - args.warmup_iters) or 1
        frac = (float(args.cur_iter) - args.warmup_iters) / decay_iters
        scale_running_lr = max((1. - frac), 0.) ** args.lr_pow
        args.running_lr = args.lr * scale_running_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = args.running_lr


def save_model(net, path, args):
    """Save ``net`` to ``path`` together with what is needed to rebuild it.

    The checkpoint holds three entries: ``'args'`` (``args.__dict__``),
    ``'kwargs'`` (the ``backbone`` and ``use_rnn`` attributes of ``net``) and
    ``'state_dict'`` (``net.state_dict()``).

    Args:
        net: Model exposing ``backbone``, ``use_rnn`` and ``state_dict()``.
        path: Destination passed to ``torch.save``.
        args: Object whose ``__dict__`` is stored alongside the weights.
    """
    checkpoint = OrderedDict({
        'args': args.__dict__,
        'kwargs': {
            'backbone': net.backbone,
            'use_rnn': net.use_rnn,
        },
        'state_dict': net.state_dict(),
    })
    torch.save(checkpoint, path)


def load_trained_model(Net, path):
    """Build a ``Net`` instance from the checkpoint at ``path``.

    Several checkpoint layouts are accepted:

    * ``'kwargs'`` present (the layout written by ``save_model``): the model
      is built with ``Net(**kwargs)``.
    * ``'backbone'`` present at the top level: the model is built from that
      backbone and from ``'use_rnn'`` if stored (default ``True``).
    * Neither present: the model is built with ``'resnet50'`` and
      ``use_rnn=True``.

    The weights are read from the ``'state_dict'`` entry; if that entry is
    missing, the checkpoint itself is assumed to be a bare state dict.

    Args:
        Net: Model class (or factory) to instantiate.
        path: Checkpoint location passed to ``torch.load``.

    Returns:
        The constructed model with the checkpoint weights loaded (on CPU).
    """
    # NOTE(security): torch.load is pickle-based; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(path, map_location='cpu')

    if 'kwargs' in checkpoint:
        # Layout produced by save_model: constructor kwargs stored explicitly.
        net = Net(**checkpoint['kwargs'])
    elif 'backbone' in checkpoint:
        # Layout with the backbone stored directly at the top level; honour a
        # stored 'use_rnn' instead of always assuming True.
        net = Net(
            checkpoint['backbone'],
            use_rnn=checkpoint.get('use_rnn', True),
        )
    else:
        # Fallback: no construction info available, use the defaults.
        net = Net('resnet50', use_rnn=True)

    # Fall back to the checkpoint itself when it has no 'state_dict' wrapper.
    weights = checkpoint.get('state_dict', checkpoint)
    net.load_state_dict(weights)
    return net