Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os, sys | |
| import os.path as osp | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.optim import Optimizer | |
| from functools import reduce | |
| from torch.optim import AdamW | |
class MultiOptimizer:
    """Bundle several named optimizers (and their LR schedulers) behind one interface.

    Optimizers are addressed by their dict key. Methods that accept ``key`` act on
    that single entry when it is given, otherwise on every key in ``self.keys``
    (the insertion order of ``optimizers``).

    Args:
        optimizers: Mapping from name to ``torch.optim.Optimizer``. Defaults to empty.
        schedulers: Mapping from name to LR scheduler. The scheduler methods look up
            schedulers by the optimizer keys, so the two mappings are expected to
            share keys.
    """

    def __init__(self, optimizers=None, schedulers=None):
        # ``None`` defaults avoid one mutable dict being shared by every instance.
        self.optimizers = {} if optimizers is None else optimizers
        self.schedulers = {} if schedulers is None else schedulers
        self.keys = list(self.optimizers.keys())
        # Flat list of all param groups across all optimizers. The empty case is
        # handled explicitly because ``reduce`` without an initializer raises on [].
        groups = [v.param_groups for v in self.optimizers.values()]
        self.param_groups = reduce(lambda x, y: x + y, groups) if groups else []

    def state_dict(self):
        """Return ``(key, optimizer_state_dict)`` pairs in ``self.keys`` order."""
        return [(key, self.optimizers[key].state_dict()) for key in self.keys]

    def scheduler_state_dict(self):
        """Return ``(key, scheduler_state_dict)`` pairs in ``self.keys`` order.

        Raises:
            KeyError: if an optimizer key has no matching scheduler.
        """
        return [(key, self.schedulers[key].state_dict()) for key in self.keys]

    def load_state_dict(self, state_dict):
        """Load optimizer states from ``(key, state)`` pairs (as from ``state_dict``).

        Best-effort: a pair whose key is unknown or whose state cannot be loaded
        is reported with a message and skipped.
        """
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except Exception:
                print("Unloaded %s" % key)

    def load_scheduler_state_dict(self, state_dict):
        """Load scheduler states from ``(key, state)`` pairs; best-effort like above."""
        for key, val in state_dict:
            try:
                self.schedulers[key].load_state_dict(val)
            except Exception:
                print("Unloaded %s" % key)

    def step(self, key=None, scaler=None):
        """Step one optimizer (``key``) or all of them, optionally through a grad scaler.

        With a ``scaler``, ``scaler.update()`` runs exactly once, after every
        selected optimizer has been stepped.
        """
        keys = [key] if key is not None else self.keys
        for k in keys:
            self._step(k, scaler, update=False)
        # GradScaler.update() may change the scale; calling it between optimizer
        # steps would unscale the remaining optimizers' grads with the wrong factor.
        if scaler is not None and keys:
            scaler.update()

    def _step(self, key, scaler=None, update=True):
        """Step a single optimizer; with a scaler, also update it unless ``update`` is False."""
        if scaler is not None:
            scaler.step(self.optimizers[key])
            if update:
                scaler.update()
        else:
            self.optimizers[key].step()

    def zero_grad(self, key=None):
        """Zero the gradients of one optimizer (``key``) or of all of them."""
        keys = [key] if key is not None else self.keys
        for k in keys:
            self.optimizers[k].zero_grad()

    def scheduler(self, *args, key=None):
        """Advance one scheduler (``key``) or all of them, forwarding ``*args``.

        When stepping all schedulers, a scheduler that defines ``step_batch`` has
        that method called; any other scheduler (e.g. ``torch.optim.lr_scheduler``
        classes, which only define ``step``) has ``step`` called.
        """
        if key is not None:
            self.schedulers[key].step(*args)
            return
        for k in self.keys:
            sched = self.schedulers[k]
            step_fn = getattr(sched, "step_batch", None)
            if step_fn is None:
                step_fn = sched.step
            step_fn(*args)
def define_scheduler(optimizer, params):
    """Create an exponential learning-rate decay schedule for ``optimizer``.

    Args:
        optimizer: Optimizer whose param-group learning rates will be decayed.
        params: Mapping that must contain ``"gamma"``, the multiplicative decay
            factor applied on each ``scheduler.step()``.

    Returns:
        A ``torch.optim.lr_scheduler.ExponentialLR`` bound to ``optimizer``.
    """
    decay = params["gamma"]
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay)
def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW", *, gamma=0.999996):
    """Create one optimizer and one exponential LR scheduler per model, bundled together.

    Args:
        model_dict: Mapping from name to ``nn.Module``. One optimizer is created per
            entry, over that module's ``parameters()``.
        scheduler_params_dict: Currently unused by this function; kept so existing
            call sites keep working.
        lr: Initial learning rate for every optimizer.
        type: Optimizer kind. Only ``"AdamW"`` is supported.
        gamma: Multiplicative decay passed to every ``ExponentialLR`` scheduler.
            Defaults to the previously hard-coded ``0.999996``.

    Returns:
        A ``MultiOptimizer`` whose optimizers and schedulers are keyed like ``model_dict``.

    Raises:
        ValueError: If ``type`` is not ``"AdamW"`` (checked even when ``model_dict`` is empty).
    """
    if type != "AdamW":
        raise ValueError("Unknown optimizer type: %s" % type)

    optim = {
        key: AdamW(
            model.parameters(),
            lr=lr,
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=0.1,
        )
        for key, model in model_dict.items()
    }
    schedulers = {
        key: torch.optim.lr_scheduler.ExponentialLR(opt, gamma=gamma)
        for key, opt in optim.items()
    }
    return MultiOptimizer(optim, schedulers)