| | |
| | |
| |
|
from typing import Optional

import torch.nn as nn
import torch.optim as optim
| |
|
| |
|
def set_optimizer(optimizer_name: str, network: nn.Module, lr: Optional[float]) -> optim.Optimizer:
    """
    Create an optimizer for the given network.

    Args:
        optimizer_name (str): optimizer name; one of 'SGD', 'Adadelta',
            'Adam', 'RMSprop', 'RAdam'
        network (torch.nn.Module): network whose parameters will be optimized
        lr (float | None): learning rate; when None, the optimizer is built
            without an explicit lr so its library default applies

    Returns:
        torch.optim.Optimizer: configured optimizer

    Raises:
        ValueError: if optimizer_name is not one of the supported names
    """
    optimizers = {
        'SGD': optim.SGD,
        'Adadelta': optim.Adadelta,
        'Adam': optim.Adam,
        'RMSprop': optim.RMSprop,
        'RAdam': optim.RAdam
    }

    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently turn a bad name into a KeyError below.
    if optimizer_name not in optimizers:
        raise ValueError(f"No specified optimizer: {optimizer_name}.")

    _optim = optimizers[optimizer_name]

    # Omit lr entirely when unspecified so each optimizer's own default is used.
    if lr is None:
        return _optim(network.parameters())
    return _optim(network.parameters(), lr=lr)
| |
|