Spaces:
Runtime error
Runtime error
| """ | |
| Optimizer | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| import torch | |
| from pointcept.utils.logger import get_root_logger | |
| from pointcept.utils.registry import Registry | |
# Registry of optimizer classes available to ``build_optimizer``; each class
# is registered under the explicit name listed below.
OPTIMIZERS = Registry("optimizers")
for _name, _optimizer_cls in (
    ("SGD", torch.optim.SGD),
    ("Adam", torch.optim.Adam),
    ("AdamW", torch.optim.AdamW),
):
    OPTIMIZERS.register_module(module=_optimizer_cls, name=_name)
# Drop the loop temporaries so they do not leak into the module namespace.
del _name, _optimizer_cls
def build_optimizer(cfg, model, param_dicts=None):
    """Build an optimizer for ``model`` from the config ``cfg``.

    Args:
        cfg: Optimizer config. This function assigns ``cfg.params`` (so the
            caller's object is mutated) and then passes ``cfg`` to
            ``OPTIMIZERS.build``. When ``param_dicts`` is given, ``cfg.lr`` is
            used as the learning rate of the default group.
        model: Object exposing ``parameters()`` and ``named_parameters()``
            (presumably a ``torch.nn.Module``).
        param_dicts: Optional sequence of mappings, each with a ``keyword``
            entry and optionally ``lr``, ``momentum`` and ``weight_decay``.
            A parameter whose name contains an entry's keyword is placed in
            that entry's group (the first matching entry wins); every other
            parameter goes into a default group that uses ``cfg.lr``.

    Returns:
        The object produced by ``OPTIMIZERS.build(cfg=cfg)``.
    """
    if param_dicts is None:
        cfg.params = model.parameters()
    else:
        cfg.params = _build_param_groups(cfg, model, param_dicts)
        # Read the groups back through ``cfg`` so that the ``names`` removal
        # applies to exactly the objects the registry will receive.
        _log_and_strip_names(cfg.params)
    return OPTIMIZERS.build(cfg=cfg)


def _build_param_groups(cfg, model, param_dicts):
    """Split ``model``'s named parameters into keyword-based param groups.

    Group 0 is the default group (``lr=cfg.lr``) for parameters matching no
    keyword; group ``i + 1`` corresponds to ``param_dicts[i]``. Every group
    also carries a ``names`` list of parameter names, used only for logging.
    """
    groups = [dict(names=[], params=[], lr=cfg.lr)]
    for param_dict in param_dicts:
        group = dict(names=[], params=[])
        # Fixed key order keeps the logged message identical across runs.
        for key in ("lr", "momentum", "weight_decay"):
            if key in param_dict:
                group[key] = param_dict[key]
        groups.append(group)

    keywords = [param_dict["keyword"] for param_dict in param_dicts]
    for name, param in model.named_parameters():
        # Substring match; the first matching keyword wins, otherwise group 0.
        index = next(
            (i + 1 for i, keyword in enumerate(keywords) if keyword in name), 0
        )
        groups[index]["names"].append(name)
        groups[index]["params"].append(param)
    return groups


def _log_and_strip_names(param_groups):
    """Log each group's settings and member names, then remove ``names``.

    The ``names`` key is popped in place so that only ``params`` and the
    hyperparameter keys remain in the groups handed to the optimizer.
    """
    logger = get_root_logger()
    for index, group in enumerate(param_groups, start=1):
        param_names = group.pop("names")
        message = "".join(
            f" {key}: {value};" for key, value in group.items() if key != "params"
        )
        logger.info(f"Params Group {index} -{message} Params: {param_names}.")