| import torch | |
def get_optimizer(
    parameters, optimizer_name: str, learning_rate: float, **kwargs
) -> torch.optim.Optimizer:
    """Build a ``torch.optim`` optimizer selected by name.

    Args:
        parameters: Passed through as the first positional argument of the
            optimizer constructor (parameters or parameter groups).
        optimizer_name: Case-insensitive name; one of ``adam``, ``sgd``,
            ``adagrad``, ``adadelta``, ``rmsprop`` or ``adamw``.
        learning_rate: Forwarded to the optimizer as ``lr``.
        **kwargs: Extra constructor arguments. They override the
            per-optimizer defaults defined in this function.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    optimizer_name = optimizer_name.lower()
    optimizer_map = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "adagrad": torch.optim.Adagrad,
        "adadelta": torch.optim.Adadelta,
        "rmsprop": torch.optim.RMSprop,
        "adamw": torch.optim.AdamW,
    }
    if optimizer_name not in optimizer_map:
        raise ValueError(
            f"Invalid optimizer name: {optimizer_name}. "
            f"Valid options: {list(optimizer_map)}"
        )

    # Per-optimizer defaults; optimizers without an entry get none.
    defaults = {
        "adam": {"weight_decay": 0.0, "amsgrad": False},
        "adamw": {"weight_decay": 0.01},
        "sgd": {"momentum": 0.9, "nesterov": True},
        "rmsprop": {"alpha": 0.99, "momentum": 0.0},
    }.get(optimizer_name, {})
    final_params = {**defaults, **kwargs}

    # torch.optim.SGD rejects nesterov=True unless momentum > 0 and
    # dampening == 0. The Nesterov default must therefore not survive a
    # caller override of momentum/dampening that makes it invalid. An
    # explicit ``nesterov=...`` from the caller is always left untouched.
    if optimizer_name == "sgd" and "nesterov" not in kwargs:
        final_params["nesterov"] = bool(
            final_params.get("momentum", 0) > 0
            and final_params.get("dampening", 0) == 0
        )

    return optimizer_map[optimizer_name](
        parameters, lr=learning_rate, **final_params
    )