| """ |
| Optimizer |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
| Please cite our work if the code is helpful to you. |
| """ |
|
|
| import copy |
| import torch |
| from pointcept.utils.logger import get_root_logger |
| from pointcept.utils.registry import Registry |
|
|
# Registry of optimizer classes; ``build_optimizer`` instantiates from it
# through ``OPTIMIZERS.build``.
OPTIMIZERS = Registry("optimizers")


# Expose the stock ``torch.optim`` optimizers under their own class names.
for _optim_name in ("SGD", "Adam", "AdamW"):
    OPTIMIZERS.register_module(
        module=getattr(torch.optim, _optim_name), name=_optim_name
    )
del _optim_name  # keep the loop variable out of the module namespace
|
|
|
|
def build_optimizer(cfg, model, param_dicts=None):
    """Build an optimizer for ``model`` from the config ``cfg``.

    ``cfg`` is deep-copied first, so the caller's config object is not mutated.

    Args:
        cfg: Optimizer config handed to ``OPTIMIZERS.build``. It must accept
            attribute assignment (``cfg.params = ...``); when ``param_dicts``
            is given it must also provide ``cfg.lr``.
        model: Object exposing ``parameters()`` and ``named_parameters()``
            (e.g. a ``torch.nn.Module``).
        param_dicts: Optional sequence of per-group overrides. Each entry is
            read with attribute access: ``.keyword`` always, and ``.lr``,
            ``.momentum``, ``.weight_decay`` only when that key is present.
            Presumably ConfigDict-like mappings; confirm against callers.
            A parameter joins the group of the FIRST entry whose ``keyword``
            is a substring of the parameter's name. Parameters matching no
            entry go to a default group whose ``lr`` is ``cfg.lr``.

    Returns:
        Whatever ``OPTIMIZERS.build(cfg=cfg)`` returns.
    """
    cfg = copy.deepcopy(cfg)
    if param_dicts is None:
        # No grouping requested: hand all parameters over as a single group.
        cfg.params = model.parameters()
        return OPTIMIZERS.build(cfg=cfg)

    # Group 0 is the catch-all for parameters that match no keyword;
    # groups[1:] mirror the order of param_dicts.
    groups = [dict(names=[], params=[], lr=cfg.lr)]
    for spec in param_dicts:
        group = dict(names=[], params=[])
        # Only these hyperparameters are forwarded; other keys are ignored.
        for option in ("lr", "momentum", "weight_decay"):
            if option in spec.keys():
                group[option] = getattr(spec, option)
        groups.append(group)

    for name, param in model.named_parameters():
        # First matching keyword wins; fall back to the catch-all group 0.
        slot = next(
            (
                pos
                for pos, spec in enumerate(param_dicts, start=1)
                if spec.keyword in name
            ),
            0,
        )
        groups[slot]["names"].append(name)
        groups[slot]["params"].append(param)
    cfg.params = groups

    logger = get_root_logger()
    for number, group in enumerate(cfg.params, start=1):
        # "names" was collected only for this log line; pop it so it is not
        # part of the groups passed to OPTIMIZERS.build.
        names = group.pop("names")
        settings = "".join(
            f" {key}: {value};" for key, value in group.items() if key != "params"
        )
        logger.info(f"Params Group {number} -{settings} Params: {names}.")
    return OPTIMIZERS.build(cfg=cfg)
|
|