Spaces:
Sleeping
Sleeping
| from torch.optim import SGD | |
| from torch.optim import Adam | |
| from torch.optim import ASGD | |
| from torch.optim import Adamax | |
| from torch.optim import Adadelta | |
| from torch.optim import Adagrad | |
| from torch.optim import RMSprop | |
# Registry mapping lower-case config names to torch.optim optimizer classes.
# get_optimizer() normalizes string names to this lower-case form before lookup.
key2opt = {
    'sgd': SGD,
    'adam': Adam,
    'asgd': ASGD,
    'adamax': Adamax,
    'adadelta': Adadelta,
    'adagrad': Adagrad,
    'rmsprop': RMSprop,
}


def get_optimizer(optimizer_name=None):
    """Return the optimizer class registered under *optimizer_name*.

    Args:
        optimizer_name: Name of the optimizer, e.g. ``'adam'`` or ``'Adam'``.
            String names are matched case-insensitively, with surrounding
            whitespace ignored. ``None`` selects SGD.

    Returns:
        The optimizer class from ``key2opt`` (the class itself, not an
        instance).

    Raises:
        NotImplementedError: If *optimizer_name* does not match any key in
            ``key2opt``. The message lists the supported names.
    """
    if optimizer_name is None:
        print("Using default 'SGD' optimizer")
        return SGD

    # Normalize only strings; other values are looked up unchanged, as before.
    if isinstance(optimizer_name, str):
        key = optimizer_name.strip().lower()
    else:
        key = optimizer_name

    if key not in key2opt:
        supported = ", ".join(sorted(key2opt))
        raise NotImplementedError(
            f"Optimizer '{optimizer_name}' not implemented "
            f"(supported: {supported})"
        )

    print(f"Using optimizer: '{optimizer_name}'")
    return key2opt[key]