FlashSR_One-step_Versatile_Audio_Super-resolution / TorchJaekwon /Train /Optimizer /OptimizerControl.py
| import torch | |
| import torch.nn as nn | |
| from HParams import HParams | |
class OptimizerControl:
    """Own a ``torch.optim`` optimizer and an optional learning-rate scheduler.

    Both objects are built by name from ``HParams().train``:

    * ``train.optimizer`` is indexed with ``"name"`` (attribute of
      ``torch.optim``) and ``"config"`` (keyword arguments for that class).
    * ``train.scheduler`` is optional; when present it is indexed with
      ``"name"`` (attribute of ``torch.optim.lr_scheduler``) and ``"config"``
      (keyword arguments), and ``lr_scheduler_step`` additionally reads its
      ``"interval"`` and ``"frequency"`` entries.
    """

    def __init__(self, model: nn.Module = None) -> None:
        """Create the controller and, if ``model`` is given, build everything.

        Args:
            model: Module whose trainable parameters are optimized. When
                ``None`` nothing is built; call ``set_optimizer`` and
                ``set_lr_scheduler`` later.
        """
        self.h_params = HParams()
        self.optimizer = None
        self.lr_scheduler = None
        # The ``train.scheduler`` mapping, kept for "interval"/"frequency".
        self.scheduler_config = None
        # Number of ``lr_scheduler_step`` calls whose interval matched.
        self.num_lr_scheduler_step = 0
        if model is not None:
            self.set_optimizer(model)
            self.set_lr_scheduler()

    def set_optimizer(self, model: nn.Module):
        """Instantiate the configured optimizer for ``model``.

        Only parameters with ``requires_grad`` set are handed to the
        optimizer. If ``torch.optim`` has no attribute with the configured
        name, ``self.optimizer`` is left untouched.

        Args:
            model: Module providing the parameters to optimize.
        """
        optimizer_name: str = self.h_params.train.optimizer["name"]
        # Shallow copy: the hyper-parameter mapping must not end up holding
        # the parameter list or the coerced values written below.
        optimizer_config: dict = dict(self.h_params.train.optimizer["config"])
        optimizer_config["params"] = [
            p for p in model.parameters() if p.requires_grad
        ]
        # Presumably these can arrive as strings (e.g. "1e-4" from a config
        # file), so they are coerced to float before use.
        for float_parameter in ("lr", "eps"):
            if float_parameter in optimizer_config:
                optimizer_config[float_parameter] = float(
                    optimizer_config[float_parameter]
                )
        optimizer_class = getattr(torch.optim, optimizer_name, None)
        # NOTE(review): an unknown name is silently ignored here, as before;
        # confirm whether callers rely on that before making it an error.
        if optimizer_class is not None:
            self.optimizer = optimizer_class(**optimizer_config)

    def optimizer_step(self):
        """Call ``step()`` on the optimizer."""
        self.optimizer.step()

    def optimizer_zero_grad(self):
        """Call ``zero_grad()`` on the optimizer."""
        self.optimizer.zero_grad()

    def optimizer_state_dict(self):
        """Return the optimizer's ``state_dict()``."""
        return self.optimizer.state_dict()

    def optimizer_load_state_dict(self, state_dict):
        """Load ``state_dict`` into the optimizer."""
        self.optimizer.load_state_dict(state_dict)

    def lr_scheduler_state_dict(self):
        """Return the scheduler's ``state_dict()``, or ``None`` without one."""
        if self.lr_scheduler is None:
            return None
        return self.lr_scheduler.state_dict()

    def lr_scheduler_load_state_dict(self, state_dict):
        """Load ``state_dict`` into the scheduler; no-op without a scheduler."""
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(state_dict)

    def set_lr_scheduler(self):
        """Instantiate the configured scheduler, if ``train.scheduler`` exists.

        Raises:
            ValueError: If ``torch.optim.lr_scheduler`` has no attribute with
                the configured name.
        """
        scheduler_dict: dict = getattr(self.h_params.train, "scheduler", None)
        if scheduler_dict is None:
            return
        self.scheduler_config = scheduler_dict
        # Shallow copy so the optimizer object is not stored back into the
        # hyper-parameter mapping.
        scheduler_parameter_dict = dict(scheduler_dict["config"])
        scheduler_parameter_dict["optimizer"] = self.optimizer
        scheduler_name = scheduler_dict["name"]
        scheduler_class = getattr(torch.optim.lr_scheduler, scheduler_name, None)
        if scheduler_class is None:
            raise ValueError(
                f"Unknown lr scheduler '{scheduler_name}': "
                "not found in torch.optim.lr_scheduler"
            )
        self.lr_scheduler = scheduler_class(**scheduler_parameter_dict)

    def lr_scheduler_step(self, interval_type="step", args=None):
        """Step the scheduler on every ``frequency``-th matching call.

        A call matches when ``interval_type`` equals
        ``scheduler_config["interval"]``. The first matching call steps the
        scheduler, then every ``scheduler_config["frequency"]``-th one after
        it. Non-matching calls, and calls made without a scheduler, do
        nothing.

        Args:
            interval_type: Label of the interval the caller is at.
            args: Unused; kept for interface compatibility.
        """
        if self.lr_scheduler is None:
            return
        if interval_type != self.scheduler_config["interval"]:
            return
        is_due = self.num_lr_scheduler_step % self.scheduler_config["frequency"] == 0
        # Count every matching call. Counting only the calls that stepped
        # leaves the counter stuck at 1 whenever frequency > 1, after which
        # the scheduler would never step again.
        self.num_lr_scheduler_step += 1
        if is_due:
            self.lr_scheduler.step()

    def get_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]["lr"]