import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from Model.ModelGan import ModelGan
from HParams import HParams
class OptimizerControlGan:
    """Hold the generator/discriminator optimizers and LR schedulers of a GAN.

    Optimizer settings are read from ``HParams().train.optimizer`` and
    scheduler settings from ``HParams().train.scheduler``. The schedulers
    start out as ``None``; ``set_lr_scheduler`` is a no-op here and is
    presumably meant to be overridden by subclasses (TODO confirm).
    """

    def __init__(self, model: "ModelGan" = None) -> None:
        """Create the controller and, when ``model`` is given, its optimizers.

        Args:
            model: Object exposing ``generator`` and ``discriminator``
                sub-modules. When ``None``, no optimizer or scheduler is
                created and ``set_optimizer`` must be called later.
        """
        self.h_params = HParams()

        self.generator_optimizer: Optimizer = None
        self.discriminator_optimizer: Optimizer = None

        self.generator_lr_scheduler: _LRScheduler = None
        self.discriminator_lr_scheduler: _LRScheduler = None
        self.scheduler_config: dict = self.h_params.train.scheduler
        # Number of scheduler-step requests seen for the matching interval.
        self.num_gen_lr_scheduler_step: int = 0
        self.num_dis_lr_scheduler_step: int = 0

        if model is not None:
            self.set_optimizer(model)
            self.set_lr_scheduler()

    def set_optimizer(self, model: "ModelGan") -> None:
        """Build both optimizers from ``self.h_params.train.optimizer``.

        The discriminator optimizer type is read from ``"discriminator_name"``
        and falls back to ``"generator_name"`` when that key is absent, so
        configurations that only declare a generator name keep working.

        Args:
            model: Object exposing ``generator`` and ``discriminator``
                sub-modules whose ``parameters()`` are optimized.
        """
        optimizer_settings = self.h_params.train.optimizer

        generator_name: str = optimizer_settings["generator_name"]
        if "discriminator_name" in optimizer_settings:
            discriminator_name: str = optimizer_settings["discriminator_name"]
        else:
            discriminator_name = generator_name

        self.generator_optimizer = self.get_optimizer(
            generator_name,
            self._optimizer_kwargs(
                optimizer_settings["generator_config"],
                model.generator.parameters(),
            ),
        )
        self.discriminator_optimizer = self.get_optimizer(
            discriminator_name,
            self._optimizer_kwargs(
                optimizer_settings["discriminator_config"],
                model.discriminator.parameters(),
            ),
        )

    @staticmethod
    def _optimizer_kwargs(config: dict, parameters) -> dict:
        """Return constructor kwargs for an optimizer without touching ``config``."""
        # Work on a copy: writing ``params`` into the hyper-parameter dict
        # would leave a parameter iterator inside the configuration.
        kwargs = dict(config)
        kwargs["params"] = parameters
        # ``float()`` also accepts strings such as "1e-4", so values that
        # arrive as text are normalized here. Keys are optional because not
        # every optimizer accepts ``eps``.
        for numeric_key in ("lr", "eps"):
            if numeric_key in kwargs:
                kwargs[numeric_key] = float(kwargs[numeric_key])
        return kwargs

    def get_optimizer(self, optimizer_name: str, optimizer_config_dict: dict) -> Optimizer:
        """Instantiate the ``torch.optim`` optimizer class named ``optimizer_name``.

        Args:
            optimizer_name: Name of a class in ``torch.optim``, e.g. ``"Adam"``.
            optimizer_config_dict: Keyword arguments for the constructor,
                including ``params``.

        Returns:
            The constructed optimizer.

        Raises:
            ValueError: If ``optimizer_name`` does not name an optimizer class
                in ``torch.optim``.
        """
        optimizer_class = None
        if isinstance(optimizer_name, str):
            optimizer_class = getattr(torch.optim, optimizer_name, None)
        is_optimizer = isinstance(optimizer_class, type) and issubclass(optimizer_class, Optimizer)
        if not is_optimizer:
            # Fail here rather than handing back None and crashing later on
            # the first attribute access.
            raise ValueError("Unknown optimizer name: {!r}".format(optimizer_name))
        return optimizer_class(**optimizer_config_dict)

    def optimizer_state_dict(self) -> dict:
        """Return both optimizer state dicts keyed by ``generator``/``discriminator``."""
        return {
            "generator": self.generator_optimizer.state_dict(),
            "discriminator": self.discriminator_optimizer.state_dict(),
        }

    def optimizer_load_state_dict(self, state_dict) -> None:
        """Restore both optimizers from a dict produced by ``optimizer_state_dict``."""
        self.generator_optimizer.load_state_dict(state_dict['generator'])
        self.discriminator_optimizer.load_state_dict(state_dict['discriminator'])

    def lr_scheduler_state_dict(self) -> dict:
        """Return the state dicts of the schedulers that exist.

        A key (``generator`` / ``discriminator``) is present only when the
        corresponding scheduler is not ``None``.
        """
        state_dict_of_lr_scheduler: dict = dict()
        if self.generator_lr_scheduler is not None:
            state_dict_of_lr_scheduler['generator'] = self.generator_lr_scheduler.state_dict()
        if self.discriminator_lr_scheduler is not None:
            state_dict_of_lr_scheduler['discriminator'] = self.discriminator_lr_scheduler.state_dict()
        return state_dict_of_lr_scheduler

    def lr_scheduler_load_state_dict(self, state_dict: dict) -> None:
        """Restore every existing scheduler from ``state_dict``.

        Raises:
            KeyError: If a scheduler exists but ``state_dict`` has no entry
                for it.
        """
        if self.generator_lr_scheduler is not None:
            self.generator_lr_scheduler.load_state_dict(state_dict['generator'])
        if self.discriminator_lr_scheduler is not None:
            self.discriminator_lr_scheduler.load_state_dict(state_dict['discriminator'])

    def set_lr_scheduler(self) -> None:
        """Create the LR schedulers; intentionally a no-op in this class."""
        pass

    def lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Forward a scheduler-step request to both schedulers."""
        self.gen_lr_scheduler_step(interval_type, args)
        self.disc_lr_scheduler_step(interval_type, args)

    def gen_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Step the generator scheduler if this request is due.

        Args:
            interval_type: Kind of event being reported; compared against the
                ``"interval"`` entry of the generator scheduler config.
            args: Accepted for interface compatibility; not used.
        """
        self._step_scheduler(
            self.generator_lr_scheduler,
            "generator_config",
            "num_gen_lr_scheduler_step",
            interval_type,
        )

    def disc_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Step the discriminator scheduler if this request is due.

        Args:
            interval_type: Kind of event being reported; compared against the
                ``"interval"`` entry of the discriminator scheduler config.
            args: Accepted for interface compatibility; not used.
        """
        self._step_scheduler(
            self.discriminator_lr_scheduler,
            "discriminator_config",
            "num_dis_lr_scheduler_step",
            interval_type,
        )

    def _step_scheduler(self, scheduler, config_key: str, counter_name: str, interval_type: str) -> None:
        """Count one matching request and step ``scheduler`` every ``frequency`` requests."""
        if scheduler is None:
            # Nothing to step; also avoids reading a scheduler config that
            # may not be populated when no scheduler is installed.
            return
        config = self.scheduler_config[config_key]
        if interval_type != config["interval"]:
            return
        requests_so_far = getattr(self, counter_name)
        # Always advance the counter. Returning before the increment (as the
        # frequency check used to do) froze the counter at 1 whenever
        # frequency > 1, so the scheduler never stepped a second time.
        setattr(self, counter_name, requests_so_far + 1)
        if requests_so_far % config["frequency"] == 0:
            scheduler.step()

    def get_lr(self) -> float:
        """Return the learning rate of the generator optimizer's first param group."""
        return self.generator_optimizer.param_groups[0]["lr"]
'''
def __init__(self,model:MelGan,h_params:HParams) -> None:
self.h_params = h_params
self.discriminator_optimizer = torch.optim.Adam(
model.discriminator.parameters(),
lr=h_params.train.lr,
weight_decay=h_params.train.weight_decay
)
self.generator_optimizer = torch.optim.Adam(
model.generator.parameters(),
lr=h_params.train.lr,
weight_decay=h_params.train.weight_decay
)
self.discriminator_state_name = "discriminator"
self.generator_state_name = "generator"
self.current_state = self.generator_state_name
self.lr_scheduler_discriminator = self.get_lr_scheduler(self.discriminator_optimizer)
self.lr_scheduler_generator = self.get_lr_scheduler(self.generator_optimizer)
def get_lr_scheduler(self,optimizer):
if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
return torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
factor=self.h_params.train.lr_decay_gamma,
patience=self.h_params.train.lr_decay_patience,
cooldown=10,
)
elif self.h_params.train.optimizer_name == "StepLR":
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.h_params.train.lr_scheduler_step_size, gamma=self.h_params.train.lr_decay_factor)
def zero_grad(self):
self.discriminator_optimizer.zero_grad()
self.generator_optimizer.zero_grad()
def step(self):
if self.current_state == self.discriminator_state_name:
self.discriminator_optimizer.step()
elif self.current_state == self.generator_state_name:
self.generator_optimizer.step()
def state_dict(self):
return {"generator": self.generator_optimizer.state_dict(),"discriminator": self.discriminator_optimizer.state_dict()}
def load_state_dict(self, state_dict_dict):
self.generator_optimizer.load_state_dict(state_dict_dict["generator"])
self.discriminator_optimizer.load_state_dict(state_dict_dict["discriminator"])
def lr_scheduler_step(self,vaild_loss=None):
if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
if self.current_state == self.discriminator_state_name:
self.lr_scheduler_discriminator.step(vaild_loss)
elif self.current_state == self.generator_state_name:
self.lr_scheduler_generator.step(vaild_loss)
elif self.h_params.train.optimizer_name == "StepLR":
if self.current_state == self.discriminator_state_name:
self.lr_scheduler_discriminator.step()
elif self.current_state == self.generator_state_name:
self.lr_scheduler_generator.step()
'''
|