Spaces:
Configuration error
Configuration error
| import torch | |
| from lib.utils.optimizer.radam import RAdam | |
# Registry from the optimizer name stored in ``cfg.train.optim`` to the
# optimizer class that ``make_optimizer`` instantiates.
_optimizer_factory = dict(
    adam=torch.optim.Adam,
    radam=RAdam,
    sgd=torch.optim.SGD,
)
def make_optimizer(cfg, net, lr=None, weight_decay=None, *, momentum=0.9):
    """Build the optimizer named by ``cfg.train.optim`` for ``net``'s trainable parameters.

    Every parameter of ``net`` with ``requires_grad`` set is placed in its own
    parameter group carrying the resolved learning rate and weight decay;
    frozen parameters are skipped.

    Args:
        cfg: Config object. Reads ``cfg.train.optim`` (a key of
            ``_optimizer_factory``), and ``cfg.train.lr`` /
            ``cfg.train.weight_decay`` when the matching argument is None.
        net: Model whose ``parameters()`` are optimized.
        lr: Learning-rate override; falls back to ``cfg.train.lr`` when None.
        weight_decay: Weight-decay override; falls back to
            ``cfg.train.weight_decay`` when None.
        momentum: Momentum passed to the non-Adam branch (SGD). Defaults to
            0.9, the value that was previously hard-coded.

    Returns:
        The constructed optimizer instance.

    Raises:
        KeyError: If ``cfg.train.optim`` is not a registered optimizer name.
    """
    optim_name = cfg.train.optim
    # Fail fast with a readable message instead of a bare KeyError raised
    # after all parameters have already been collected.
    if optim_name not in _optimizer_factory:
        raise KeyError(
            f"Unknown optimizer {optim_name!r}; "
            f"expected one of {sorted(_optimizer_factory)}"
        )
    optimizer_cls = _optimizer_factory[optim_name]

    lr = cfg.train.lr if lr is None else lr
    weight_decay = cfg.train.weight_decay if weight_decay is None else weight_decay

    # One group per trainable parameter, each carrying its own lr/weight_decay.
    params = [
        {"params": [param], "lr": lr, "weight_decay": weight_decay}
        for param in net.parameters()
        if param.requires_grad
    ]

    # Substring match: covers both 'adam' and 'radam'.
    if 'adam' in optim_name:
        return optimizer_cls(params, lr, weight_decay=weight_decay)
    # The SGD branch gets weight decay through the per-group entries above.
    return optimizer_cls(params, lr, momentum=momentum)