File size: 7,750 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from Model.ModelGan import ModelGan
from HParams import HParams

class OptimizerControlGan:
    """Own and drive the optimizers and LR schedulers of a GAN's two networks.

    Optimizer names and constructor kwargs come from
    ``HParams().train.optimizer``; the scheduler stepping policy (an
    ``interval`` name and a ``frequency`` per network) comes from
    ``HParams().train.scheduler``. This base class creates no schedulers
    (``set_lr_scheduler`` is a no-op), so both scheduler attributes stay
    ``None`` unless a subclass or caller assigns them.
    """

    # Optimizer kwargs that may be stored as strings in the config and are
    # therefore converted with float() before reaching the torch constructor.
    _FLOAT_CONFIG_KEYS = ("lr", "eps", "weight_decay")

    def __init__(self, model: "ModelGan" = None) -> None:
        """Load hyper-parameters and, if *model* is given, build optimizers/schedulers.

        Args:
            model: GAN exposing ``generator`` and ``discriminator`` modules.
                When ``None``, call ``set_optimizer``/``set_lr_scheduler`` later.
        """
        self.h_params = HParams()

        self.generator_optimizer: Optimizer = None
        self.discriminator_optimizer: Optimizer = None

        self.generator_lr_scheduler: _LRScheduler = None
        self.discriminator_lr_scheduler: _LRScheduler = None

        # Expected to hold "generator_config"/"discriminator_config" dicts,
        # each with "interval" and "frequency" keys (see _scheduler_step).
        self.scheduler_config: dict = self.h_params.train.scheduler
        # Count of matching-interval step requests seen per network.
        self.num_gen_lr_scheduler_step: int = 0
        self.num_dis_lr_scheduler_step: int = 0

        if model is not None:
            self.set_optimizer(model)
            self.set_lr_scheduler()

    @classmethod
    def _prepare_optimizer_config(cls, config: dict, parameters) -> dict:
        """Return a copy of *config* with ``params`` set and numeric kwargs cast to float.

        The input dict is never modified, so the shared hyper-parameters keep
        their original values and never hold a parameter generator.
        """
        prepared = dict(config)
        prepared["params"] = parameters
        for key in cls._FLOAT_CONFIG_KEYS:
            if key in prepared:
                prepared[key] = float(prepared[key])
        return prepared

    def set_optimizer(self, model: "ModelGan") -> None:
        """Create the generator and discriminator optimizers for *model*.

        Reads ``generator_name``/``generator_config`` and
        ``discriminator_name``/``discriminator_config`` from
        ``h_params.train.optimizer``.

        Raises:
            ValueError: if an optimizer name is not a ``torch.optim`` optimizer.
        """
        optimizer_params: dict = self.h_params.train.optimizer

        generator_optimizer_name: str = optimizer_params["generator_name"]
        generator_optimizer_config = self._prepare_optimizer_config(
            optimizer_params["generator_config"], model.generator.parameters()
        )
        self.generator_optimizer = self.get_optimizer(
            generator_optimizer_name, generator_optimizer_config
        )

        # Fall back to the generator's optimizer so configs without a
        # "discriminator_name" key keep working.
        discriminator_optimizer_name: str = optimizer_params.get(
            "discriminator_name", generator_optimizer_name
        )
        discriminator_optimizer_config = self._prepare_optimizer_config(
            optimizer_params["discriminator_config"], model.discriminator.parameters()
        )
        self.discriminator_optimizer = self.get_optimizer(
            discriminator_optimizer_name, discriminator_optimizer_config
        )

    def get_optimizer(self, optimizer_name: str, optimizer_config_dict: dict) -> Optimizer:
        """Instantiate the ``torch.optim`` optimizer class named *optimizer_name*.

        Args:
            optimizer_name: class name in ``torch.optim`` (e.g. ``"Adam"``).
            optimizer_config_dict: keyword arguments for the constructor,
                including ``params``.

        Raises:
            ValueError: if the name does not resolve to an optimizer class.
        """
        optimizer_class = getattr(torch.optim, optimizer_name, None)
        if (
            not isinstance(optimizer_class, type)
            or not issubclass(optimizer_class, Optimizer)
            or optimizer_class is Optimizer
        ):
            raise ValueError(f"Unsupported optimizer name: {optimizer_name!r}")
        return optimizer_class(**optimizer_config_dict)

    def optimizer_state_dict(self) -> dict:
        """Return ``{"generator": ..., "discriminator": ...}`` optimizer state dicts."""
        return {
            "generator": self.generator_optimizer.state_dict(),
            "discriminator": self.discriminator_optimizer.state_dict(),
        }

    def optimizer_load_state_dict(self, state_dict) -> None:
        """Restore both optimizers from a dict produced by ``optimizer_state_dict``."""
        self.generator_optimizer.load_state_dict(state_dict["generator"])
        self.discriminator_optimizer.load_state_dict(state_dict["discriminator"])

    def lr_scheduler_state_dict(self) -> dict:
        """Return state dicts of the schedulers that exist; ``None`` schedulers are omitted."""
        state_dict_of_lr_scheduler: dict = dict()

        if self.generator_lr_scheduler is not None:
            state_dict_of_lr_scheduler["generator"] = self.generator_lr_scheduler.state_dict()

        if self.discriminator_lr_scheduler is not None:
            state_dict_of_lr_scheduler["discriminator"] = self.discriminator_lr_scheduler.state_dict()

        return state_dict_of_lr_scheduler

    def lr_scheduler_load_state_dict(self, state_dict: dict) -> None:
        """Restore existing schedulers from *state_dict*.

        Entries absent from *state_dict* are skipped, mirroring
        ``lr_scheduler_state_dict``, which omits schedulers that were ``None``.
        """
        if self.generator_lr_scheduler is not None and "generator" in state_dict:
            self.generator_lr_scheduler.load_state_dict(state_dict["generator"])

        if self.discriminator_lr_scheduler is not None and "discriminator" in state_dict:
            self.discriminator_lr_scheduler.load_state_dict(state_dict["discriminator"])

    def set_lr_scheduler(self) -> None:
        """Hook for creating LR schedulers; the base class creates none."""
        pass

    @staticmethod
    def _scheduler_step(scheduler, scheduler_config: dict, request_count: int,
                        interval_type: str, args) -> int:
        """Step *scheduler* on every ``frequency``-th matching request; return the new count.

        Requests whose *interval_type* differs from ``scheduler_config["interval"]``
        are ignored and not counted. Matching requests are counted whether or
        not a scheduler exists; the scheduler steps on request 0, f, 2f, ...
        ``ReduceLROnPlateau`` receives *args* as its metric; other schedulers
        are stepped without arguments.
        """
        if interval_type != scheduler_config["interval"]:
            return request_count

        if request_count % scheduler_config["frequency"] == 0 and scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(args)
            else:
                scheduler.step()

        return request_count + 1

    def lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Request a step of both schedulers for an event of kind *interval_type*."""
        self.gen_lr_scheduler_step(interval_type, args)
        self.disc_lr_scheduler_step(interval_type, args)

    def gen_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Request a generator-scheduler step (see ``_scheduler_step``)."""
        self.num_gen_lr_scheduler_step = self._scheduler_step(
            self.generator_lr_scheduler,
            self.scheduler_config["generator_config"],
            self.num_gen_lr_scheduler_step,
            interval_type,
            args,
        )

    def disc_lr_scheduler_step(self, interval_type: str = "step", args=None) -> None:
        """Request a discriminator-scheduler step (see ``_scheduler_step``)."""
        self.num_dis_lr_scheduler_step = self._scheduler_step(
            self.discriminator_lr_scheduler,
            self.scheduler_config["discriminator_config"],
            self.num_dis_lr_scheduler_step,
            interval_type,
            args,
        )

    def get_lr(self) -> float:
        """Return the current LR of the generator optimizer's first parameter group."""
        return self.generator_optimizer.param_groups[0]["lr"]

# NOTE(review): The triple-quoted string below is a disabled legacy
# implementation. It was written against a `MelGan` model and flat
# `h_params.train` keys such as `lr`, `weight_decay` and `optimizer_name`.
# As a bare module-level string it is evaluated and discarded, so none of
# its code runs. Consider deleting it and relying on version control history.
'''
    def __init__(self,model:MelGan,h_params:HParams) -> None:
        self.h_params = h_params
        self.discriminator_optimizer = torch.optim.Adam(
                            model.discriminator.parameters(),
                            lr=h_params.train.lr, 
                            weight_decay=h_params.train.weight_decay
                            )
        self.generator_optimizer = torch.optim.Adam(
                            model.generator.parameters(),
                            lr=h_params.train.lr, 
                            weight_decay=h_params.train.weight_decay
                            )
        
        self.discriminator_state_name = "discriminator"
        self.generator_state_name = "generator"
        self.current_state = self.generator_state_name

        self.lr_scheduler_discriminator = self.get_lr_scheduler(self.discriminator_optimizer)
        
        self.lr_scheduler_generator = self.get_lr_scheduler(self.generator_optimizer)

    def get_lr_scheduler(self,optimizer):
        if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer,
                            factor=self.h_params.train.lr_decay_gamma,
                            patience=self.h_params.train.lr_decay_patience,
                            cooldown=10,
                            )
        elif self.h_params.train.optimizer_name == "StepLR":
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.h_params.train.lr_scheduler_step_size, gamma=self.h_params.train.lr_decay_factor)
    
    def zero_grad(self):
        self.discriminator_optimizer.zero_grad()
        self.generator_optimizer.zero_grad()
    
    def step(self):
        if self.current_state == self.discriminator_state_name:
            self.discriminator_optimizer.step()
        elif self.current_state == self.generator_state_name:
            self.generator_optimizer.step()
    
    def state_dict(self):
        return {"generator": self.generator_optimizer.state_dict(),"discriminator": self.discriminator_optimizer.state_dict()}
    
    def load_state_dict(self, state_dict_dict):
        self.generator_optimizer.load_state_dict(state_dict_dict["generator"])
        self.discriminator_optimizer.load_state_dict(state_dict_dict["discriminator"])
    
    def lr_scheduler_step(self,vaild_loss=None):
        if self.h_params.train.optimizer_name == "ReduceLROnPlateau":
            if self.current_state == self.discriminator_state_name:
                self.lr_scheduler_discriminator.step(vaild_loss)
            elif self.current_state == self.generator_state_name:
                self.lr_scheduler_generator.step(vaild_loss)
        elif self.h_params.train.optimizer_name == "StepLR":
            if self.current_state == self.discriminator_state_name:
                self.lr_scheduler_discriminator.step()
            elif self.current_state == self.generator_state_name:
                self.lr_scheduler_generator.step()
'''