import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from Model.ModelGan import ModelGan
from HParams import HParams


class OptimizerControlGan:
    """Hold the optimizers and learning-rate schedulers of a GAN.

    One optimizer and (optionally) one scheduler are kept for each of the
    generator and the discriminator. Optimizer settings are read from
    ``h_params.train.optimizer`` and scheduler settings from
    ``h_params.train.scheduler``.

    ``set_lr_scheduler`` is a no-op in this class, so both schedulers stay
    ``None`` unless something else assigns them (presumably a subclass
    override -- confirm against the rest of the project).
    """

    def __init__(self, model: ModelGan = None) -> None:
        """Create the controller and, if ``model`` is given, its optimizers.

        Args:
            model: Object exposing ``generator`` and ``discriminator``
                attributes with a ``parameters()`` method. When ``None``,
                no optimizer or scheduler is created; call
                ``set_optimizer`` / ``set_lr_scheduler`` later.
        """
        self.h_params = HParams()

        self.generator_optimizer: Optimizer = None
        self.discriminator_optimizer: Optimizer = None

        self.generator_lr_scheduler: _LRScheduler = None
        self.discriminator_lr_scheduler: _LRScheduler = None
        # Read by the *_lr_scheduler_step methods: needs "generator_config"
        # and "discriminator_config" entries, each with "frequency" and
        # "interval".
        self.scheduler_config: dict = self.h_params.train.scheduler

        # Number of step requests seen with a matching interval type.
        self.num_gen_lr_scheduler_step: int = 0
        self.num_dis_lr_scheduler_step: int = 0

        if model is not None:
            self.set_optimizer(model)
            self.set_lr_scheduler()

    @staticmethod
    def _build_optimizer_kwargs(config: dict, params) -> dict:
        """Return constructor kwargs for an optimizer without touching ``config``."""
        # Work on a copy: writing ``params`` (a live generator) and the
        # converted numbers back into the hyper-parameter dict would pollute
        # configuration shared with the rest of the program.
        kwargs = dict(config)
        kwargs["params"] = params
        # Config values may arrive as strings (e.g. "1e-4"); coerce them.
        # Only keys that are present are converted so optimizers without an
        # ``eps`` argument can be configured too.
        for key in ("lr", "eps"):
            if key in kwargs:
                kwargs[key] = float(kwargs[key])
        return kwargs

    def set_optimizer(self, model: ModelGan) -> None:
        """Build the generator and discriminator optimizers for ``model``.

        Reads ``generator_name`` / ``generator_config`` and
        ``discriminator_name`` / ``discriminator_config`` from
        ``h_params.train.optimizer``. If ``discriminator_name`` is absent,
        the generator's optimizer name is reused, which keeps configs that
        only define ``generator_name`` working.

        Args:
            model: Object exposing ``generator`` and ``discriminator``
                attributes with a ``parameters()`` method.
        """
        optimizer_settings = self.h_params.train.optimizer

        generator_name: str = optimizer_settings["generator_name"]
        if "discriminator_name" in optimizer_settings:
            discriminator_name: str = optimizer_settings["discriminator_name"]
        else:
            discriminator_name = generator_name

        self.generator_optimizer = self.get_optimizer(
            generator_name,
            self._build_optimizer_kwargs(
                optimizer_settings["generator_config"],
                model.generator.parameters(),
            ),
        )
        self.discriminator_optimizer = self.get_optimizer(
            discriminator_name,
            self._build_optimizer_kwargs(
                optimizer_settings["discriminator_config"],
                model.discriminator.parameters(),
            ),
        )

    def get_optimizer(self, optimizer_name: str, optimizer_config_dict: dict) -> Optimizer:
        """Instantiate the ``torch.optim`` optimizer called ``optimizer_name``.

        Args:
            optimizer_name: Class name inside ``torch.optim`` (e.g. ``"Adam"``).
            optimizer_config_dict: Keyword arguments for the optimizer
                constructor, including ``params``.

        Returns:
            The constructed optimizer.

        Raises:
            ValueError: If ``optimizer_name`` does not name a concrete
                optimizer class in ``torch.optim``.
        """
        optimizer_class = None
        if isinstance(optimizer_name, str):
            optimizer_class = getattr(torch.optim, optimizer_name, None)

        is_concrete_optimizer = (
            isinstance(optimizer_class, type)
            and issubclass(optimizer_class, Optimizer)
            and optimizer_class is not Optimizer
        )
        if not is_concrete_optimizer:
            # Fail here with a clear message instead of returning None and
            # crashing later on the first attribute access.
            raise ValueError(f"Unsupported optimizer name: {optimizer_name!r}")

        return optimizer_class(**optimizer_config_dict)

    def optimizer_state_dict(self) -> dict:
        """Return both optimizer states keyed by "generator" and "discriminator"."""
        return {
            "generator": self.generator_optimizer.state_dict(),
            "discriminator": self.discriminator_optimizer.state_dict(),
        }

    def optimizer_load_state_dict(self, state_dict) -> None:
        """Restore both optimizers from a dict made by ``optimizer_state_dict``."""
        self.generator_optimizer.load_state_dict(state_dict['generator'])
        self.discriminator_optimizer.load_state_dict(state_dict['discriminator'])

    def lr_scheduler_state_dict(self) -> dict:
        """Return the states of the schedulers that exist.

        A key ("generator" / "discriminator") is present only when the
        corresponding scheduler is not ``None``.
        """
        state_dict_of_lr_scheduler: dict = dict()
        if self.generator_lr_scheduler is not None:
            state_dict_of_lr_scheduler['generator'] = self.generator_lr_scheduler.state_dict()
        if self.discriminator_lr_scheduler is not None:
            state_dict_of_lr_scheduler['discriminator'] = self.discriminator_lr_scheduler.state_dict()
        return state_dict_of_lr_scheduler

    def lr_scheduler_load_state_dict(self, state_dict: dict) -> None:
        """Restore the schedulers that exist from ``state_dict``.

        Raises:
            KeyError: If a scheduler exists but its key is missing from
                ``state_dict``.
        """
        if self.generator_lr_scheduler is not None:
            self.generator_lr_scheduler.load_state_dict(state_dict['generator'])
        if self.discriminator_lr_scheduler is not None:
            self.discriminator_lr_scheduler.load_state_dict(state_dict['discriminator'])

    def set_lr_scheduler(self) -> None:
        """Create the LR schedulers; intentionally does nothing in this class."""
        pass

    def lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Forward a step request to the generator and discriminator schedulers."""
        self.gen_lr_scheduler_step(interval_type, args)
        self.disc_lr_scheduler_step(interval_type, args)

    def gen_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Step the generator scheduler if this request is due.

        Requests whose ``interval_type`` differs from the configured
        ``interval`` are ignored. Matching requests are counted, and the
        scheduler is stepped on the first one and then on every
        ``frequency``-th one after it.

        Args:
            interval_type: Kind of event that triggered the request.
            args: Unused here; kept so existing callers keep working.
        """
        scheduler = self.generator_lr_scheduler
        if scheduler is None:
            # Nothing to step, so the scheduler config is not needed either.
            return

        config = self.scheduler_config["generator_config"]
        if interval_type != config["interval"]:
            return

        if self.num_gen_lr_scheduler_step % config["frequency"] == 0:
            scheduler.step()
        # Count every matching request, not only the ones that stepped;
        # otherwise the counter stalls and no later request is ever due.
        self.num_gen_lr_scheduler_step += 1

    def disc_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Step the discriminator scheduler if this request is due.

        Requests whose ``interval_type`` differs from the configured
        ``interval`` are ignored. Matching requests are counted, and the
        scheduler is stepped on the first one and then on every
        ``frequency``-th one after it.

        Args:
            interval_type: Kind of event that triggered the request.
            args: Unused here; kept so existing callers keep working.
        """
        scheduler = self.discriminator_lr_scheduler
        if scheduler is None:
            # Nothing to step, so the scheduler config is not needed either.
            return

        config = self.scheduler_config["discriminator_config"]
        if interval_type != config["interval"]:
            return

        if self.num_dis_lr_scheduler_step % config["frequency"] == 0:
            scheduler.step()
        # Count every matching request, not only the ones that stepped;
        # otherwise the counter stalls and no later request is ever due.
        self.num_dis_lr_scheduler_step += 1

    def get_lr(self) -> float:
        """Return the learning rate of the generator optimizer's first param group."""
        return self.generator_optimizer.param_groups[0]["lr"]


# Earlier implementation, kept for reference only. It is an unused string
# literal and is never executed.
'''
    def __init__(self,model:MelGan,h_params:HParams) -> None:
        self.h_params = h_params
        self.discriminator_optimizer = torch.optim.Adam(
            model.discriminator.parameters(),
            lr=h_params.train.lr,
            weight_decay=h_params.train.weight_decay
        )
        self.generator_optimizer = torch.optim.Adam(
            model.generator.parameters(),
            lr=h_params.train.lr,
            weight_decay=h_params.train.weight_decay
        )
        self.discriminator_state_name = "discriminator"
        self.generator_state_name = "generator"
        self.current_state = self.generator_state_name

        self.lr_scheduler_discriminator = self.get_lr_scheduler(self.discriminator_optimizer)
        self.lr_scheduler_generator = self.get_lr_scheduler(self.generator_optimizer)

    def get_lr_scheduler(self,optimizer):
        if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.h_params.train.lr_decay_gamma,
                patience=self.h_params.train.lr_decay_patience,
                cooldown=10,
            )
        elif self.h_params.train.optimizer_name == "StepLR":
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.h_params.train.lr_scheduler_step_size, gamma=self.h_params.train.lr_decay_factor)

    def zero_grad(self):
        self.discriminator_optimizer.zero_grad()
        self.generator_optimizer.zero_grad()

    def step(self):
        if self.current_state == self.discriminator_state_name:
            self.discriminator_optimizer.step()
        elif self.current_state == self.generator_state_name:
            self.generator_optimizer.step()

    def state_dict(self):
        return {"generator": self.generator_optimizer.state_dict(),"discriminator": self.discriminator_optimizer.state_dict()}

    def load_state_dict(self, state_dict_dict):
        self.generator_optimizer.load_state_dict(state_dict_dict["generator"])
        self.discriminator_optimizer.load_state_dict(state_dict_dict["discriminator"])

    def lr_scheduler_step(self,vaild_loss=None):
        if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
            if self.current_state == self.discriminator_state_name:
                self.lr_scheduler_discriminator.step(vaild_loss)
            elif self.current_state == self.generator_state_name:
                self.lr_scheduler_generator.step(vaild_loss)
        elif self.h_params.train.optimizer_name == "StepLR":
            if self.current_state == self.discriminator_state_name:
                self.lr_scheduler_discriminator.step()
            elif self.current_state == self.generator_state_name:
                self.lr_scheduler_generator.step()
'''