File size: 4,693 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from torch.optim import Optimizer
import math

class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(epoch = last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

class CosineSchedule(_LRScheduler):
    """Cosine decay of each base lr over ``K`` scheduler steps.

    The lr at step ``t`` is ``base_lr * cos(99*pi*t / (200*(K-1)))``: the
    99/200 factor stops just short of ``cos(pi/2)``, so the lr decays toward
    (but never exactly to) zero at the final step.
    """

    def __init__(self, optimizer, K):
        # K: total number of steps the decay is spread over.
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        """Return the decayed lr for the current ``last_epoch``.

        The two branches of the original duplicated the whole formula just to
        avoid a zero denominator at K == 1; a single guarded span expression
        produces identical values for every K.
        """
        # K == 1 would zero the denominator; treat it like K == 2 so a
        # single-step schedule simply keeps (almost) the base lr at step 0.
        span = self.K - 1 if self.K != 1 else 1
        return base_lr * math.cos((99 * math.pi * self.last_epoch) / (200 * span))

    def get_lr(self):
        return [self.cosine(base_lr) for base_lr in self.base_lrs]

    def get_last_lr(self):
        # Convenience accessor matching the torch scheduler API surface.
        return self.get_lr()

class CosineAnnealingWarmUp(_LRScheduler):
    """Linear warmup for ``warmup_length`` epochs, then cosine annealing.

    NOTE(review): ``T_max`` defaults to 0 and is only validated inside
    :meth:`get_last_lr`; callers presumably assign a real ``T_max`` after
    construction, before the first post-warmup step — confirm against usage.
    """

    def __init__(self, optimizer, warmup_length, T_max=0, last_epoch=-1):
        self.warmup_length = warmup_length
        self.T_max = T_max
        # The base __init__ overwrites this; kept to match the original flow.
        self.last_epoch = last_epoch

        super().__init__(optimizer, last_epoch)

    def cosine_lr(self, base_lr):
        # Cosine annealing from base_lr toward 0 as last_epoch approaches T_max.
        return base_lr * 0.5 * (1 + math.cos(math.pi * self.last_epoch / self.T_max))

    def warmup_lr(self, base_lr):
        # Linear ramp: fraction (last_epoch + 1) / warmup_length of the base lr.
        return base_lr * (self.last_epoch + 1) / self.warmup_length

    def get_lr(self):
        """One lr per param group: warmup rule first, cosine rule afterwards."""
        rule = self.warmup_lr if self.last_epoch < self.warmup_length else self.cosine_lr
        return [rule(base_lr) for base_lr in self.base_lrs]

    def get_last_lr(self):
        assert self.T_max > 0, 'CosineAnnealingWarmUp is called with T_max <= 0, Check your code'
        return self.get_lr()

class PatienceSchedule(_LRScheduler):
    """Divide the lr by ``factor`` after ``patience`` steps with no new best loss."""

    def __init__(self, optimizer, patience, factor):
        self.factor = factor            # divisor applied to the lr on plateau
        self.patience = patience        # allowed steps without improvement
        self.best_loss = float('inf')   # lowest loss observed so far
        self.counter = 0                # consecutive steps without improvement

        super().__init__(optimizer, -1)

    def step(self, current_loss=None, **kwargs):
        # The base __init__ (and some training loops) call step(epoch=...);
        # **kwargs swallows that so only loss-driven calls do any work.

        if current_loss is None:
            return 0

        if current_loss < self.best_loss:
            # New best: remember it and restart the patience window.
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            # Plateau detected: shrink every group's lr and reset the window.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] /= self.factor
            print(f"Reducing learning rate to {self.optimizer.param_groups[0]['lr']:.5f}")
            self.counter = 0

    def get_last_lr(self):
        # NOTE(review): unlike the other schedulers here, this returns a
        # scalar (first group's lr) rather than a list.
        return self.optimizer.param_groups[0]['lr']