File size: 2,611 Bytes
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch.optim as optim
import torch.nn as nn
import torch
import itertools


def add_full_model_gradient_clipping(optim, clip_norm_val):
    """Return a subclass of ``optim`` that clips the global gradient norm.

    Before every parameter update, the gradients of *all* parameters across
    all param groups are jointly rescaled with
    ``torch.nn.utils.clip_grad_norm_`` so that their combined norm does not
    exceed ``clip_norm_val``.

    Args:
        optim: A ``torch.optim.Optimizer`` subclass to wrap (e.g.
            ``torch.optim.SGD``). Note that inside this function the name
            shadows any module-level ``optim`` import.
        clip_norm_val: Maximum allowed total gradient norm.

    Returns:
        A new optimizer class, constructed with the same arguments as
        ``optim``, whose ``step`` returns whatever the base ``step`` returns.
    """

    def _clip_all(param_groups):
        # Clip over the union of every group's params so the norm is global,
        # not per-group.
        all_params = itertools.chain.from_iterable(g["params"] for g in param_groups)
        torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)

    class FullModelGradientClippingOptimizer(optim):
        def step(self, closure=None):
            """Clip gradients to ``clip_norm_val`` and perform one update.

            When ``closure`` is given, clipping runs inside it, right after
            it recomputes the gradients. Clipping before calling the base
            ``step`` would act on stale gradients that the closure then
            overwrites.

            Returns:
                The base optimizer's return value (the closure's loss, or
                ``None`` when no closure is passed).
            """
            if closure is None:
                _clip_all(self.param_groups)
                return super().step()

            def clipped_closure():
                loss = closure()
                _clip_all(self.param_groups)
                return loss

            return super().step(clipped_closure)

    return FullModelGradientClippingOptimizer


class Optimizer(object):
    """Bundle a torch optimizer with its learning-rate scheduler.

    Args:
        models: Iterable of ``nn.Module`` and/or ``nn.Parameter`` objects
            whose parameters are optimized with the base learning rate.
        training_params: Config dict. Keys read: ``'lr'``,
            ``'weight_decay'``, ``'optimizer'`` (``'SGD'`` or ``'AdamW'``),
            ``'momentum'`` (SGD only), ``'lr_schedule'`` (name of a class in
            ``torch.optim.lr_scheduler``), ``'schedule_params'`` (keyword
            arguments for that scheduler) and, for ``'CosineAnnealingLR'``,
            ``'inter_val'``.
        sep_lr: Optional learning rate for a second param group.
        sep_params: Modules/parameters forming that second group; required
            when ``sep_lr`` is given.
        gradient_clip: If > 0, the total gradient norm over all parameters
            is clipped to this value before each optimizer step.

    Raises:
        TypeError: ``sep_lr`` is given without ``sep_params``.
        ValueError: ``training_params['optimizer']`` is not supported.
    """

    def __init__(self, models, training_params, sep_lr=None, sep_params=None, gradient_clip=0):
        params = self._collect_params(models)
        if sep_lr is not None:
            if sep_params is None:
                raise TypeError('sep_params must be provided when sep_lr is set')
            params = [{'params': params},
                      {'params': self._collect_params(sep_params), 'lr': sep_lr}]

        self.lr = training_params['lr']
        self.weight_decay = training_params['weight_decay']
        method = training_params['optimizer']

        if method == 'SGD':
            self.momentum = training_params['momentum']
            optim_cls = optim.SGD
            extra_kwargs = {'momentum': self.momentum}
        elif method == 'AdamW':
            optim_cls = optim.AdamW
            extra_kwargs = {}
        else:
            # ValueError subclasses Exception, so `except Exception` callers still work.
            raise ValueError('{} is not supported'.format(method))

        # Apply clipping for every supported optimizer, not only SGD.
        if gradient_clip and gradient_clip > 0:
            optim_cls = add_full_model_gradient_clipping(optim_cls, gradient_clip)
        self.optim = optim_cls(params, lr=self.lr, weight_decay=self.weight_decay, **extra_kwargs)

        schedule_name = training_params['lr_schedule']
        # Copy so the caller's config dict is not mutated by the T_max override.
        schedule_params = dict(training_params['schedule_params'] or {})
        if schedule_name == 'CosineAnnealingLR':
            schedule_params['T_max'] = training_params['inter_val'] * 4
        self.lr_schedule = getattr(optim.lr_scheduler, schedule_name)(self.optim, **schedule_params)

    @staticmethod
    def _collect_params(items):
        """Flatten an iterable of modules/parameters into a list of parameters."""
        params = []
        for item in items:
            if isinstance(item, nn.Parameter):
                params.append(item)
            else:
                params.extend(item.parameters())
        return params

    def update_lr(self):
        """Advance the learning-rate scheduler by one step."""
        self.lr_schedule.step()

    def z_grad(self):
        """Zero the gradients of all optimized parameters."""
        self.optim.zero_grad()

    def g_step(self):
        """Perform one optimizer update."""
        self.optim.step()

    def get_lr(self):
        """Return the current learning rate of the first param group (``None`` if none)."""
        return next((group['lr'] for group in self.optim.param_groups), None)