Spaces: Sleeping
| import torch | |
def get_optimizer(
    optimizer_name: str, latents: torch.Tensor, lr: float, nesterov: bool
):
    """Build a torch optimizer whose only parameter is ``latents``.

    Args:
        optimizer_name: One of ``"adam"``, ``"sgd"`` or ``"lbfgs"``
            (exact, case-sensitive match).
        latents: The tensor to optimize; it is wrapped in a one-element
            parameter list. Presumably a leaf tensor with
            ``requires_grad=True`` -- confirm at call sites.
        lr: Learning rate forwarded to whichever optimizer is built.
        nesterov: Forwarded to SGD only; Adam and L-BFGS ignore it.

    Returns:
        A ``torch.optim.Adam``, ``torch.optim.SGD`` or ``torch.optim.LBFGS``
        instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    params = [latents]

    if optimizer_name == "adam":
        # eps is far larger than Adam's default (1e-8), damping the update
        # when the second-moment estimate is small.
        return torch.optim.Adam(params, lr=lr, eps=1e-2)

    if optimizer_name == "sgd":
        # Fixed momentum of 0.9; `nesterov` switches to Nesterov momentum.
        return torch.optim.SGD(params, lr=lr, momentum=0.9, nesterov=nesterov)

    if optimizer_name == "lbfgs":
        # Short curvature history plus a strong-Wolfe line search; each
        # .step(closure) call runs at most max_iter inner iterations.
        return torch.optim.LBFGS(
            params,
            lr=lr,
            max_iter=10,
            history_size=3,
            line_search_fn="strong_wolfe",
        )

    raise ValueError(f"Unknown optimizer {optimizer_name}")