File size: 7,786 Bytes
70c471b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """
        Base RMSprop optimizer with a configurable optimizer-state bit width.

        Arguments:
            params (`torch.tensor`):
                The input parameters to optimize; forwarded unchanged to `Optimizer1State`.
            lr (`float`, defaults to 1e-2):
                The learning rate.
            alpha (`float`, defaults to 0.99):
                The decay rate of the running average of squared gradients. Must be non-zero.
            eps (`float`, defaults to 1e-8):
                The epsilon value that prevents division by zero in the optimizer.
            weight_decay (`float`, defaults to 0):
                The weight decay value for the optimizer.
            momentum (`float`, defaults to 0):
                The momentum factor. It is passed to the base class together with `alpha`.
            centered (`bool`, defaults to `False`):
                Centered RMSprop (normalizing by the estimated gradient variance) is not
                supported. This must be `False`; passing `True` raises `NotImplementedError`.
            optim_bits (`int`, defaults to 32):
                The number of bits of the optimizer state.
            args (`object`, defaults to `None`):
                An object with additional arguments.
            min_8bit_size (`int`, defaults to 4096):
                The minimum number of elements of the parameter tensors for 8-bit optimization.
            percentile_clipping (`int`, defaults to 100):
                Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
            block_wise (`bool`, defaults to `True`):
                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.

        Raises:
            NotImplementedError: If `alpha == 0` or `centered` is `True`.
        """
        # Reject configurations this implementation does not handle before
        # any state is set up in the base class.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")
        # Arguments are forwarded positionally. (alpha, momentum) travels as the
        # base class's two-element hyperparameter tuple.
        super().__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class RMSprop8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """
        8-bit RMSprop optimizer.

        Same as `RMSprop` with the optimizer-state bit width fixed to 8. There is
        no `optim_bits` parameter.

        Arguments:
            params (`torch.tensor`):
                The input parameters to optimize; forwarded unchanged to `Optimizer1State`.
            lr (`float`, defaults to 1e-2):
                The learning rate.
            alpha (`float`, defaults to 0.99):
                The decay rate of the running average of squared gradients. Must be non-zero.
            eps (`float`, defaults to 1e-8):
                The epsilon value that prevents division by zero in the optimizer.
            weight_decay (`float`, defaults to 0):
                The weight decay value for the optimizer.
            momentum (`float`, defaults to 0):
                The momentum factor. It is passed to the base class together with `alpha`.
            centered (`bool`, defaults to `False`):
                Centered RMSprop (normalizing by the estimated gradient variance) is not
                supported. This must be `False`; passing `True` raises `NotImplementedError`.
            args (`object`, defaults to `None`):
                An object with additional arguments.
            min_8bit_size (`int`, defaults to 4096):
                The minimum number of elements of the parameter tensors for 8-bit optimization.
            percentile_clipping (`int`, defaults to 100):
                Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
            block_wise (`bool`, defaults to `True`):
                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.

        Raises:
            NotImplementedError: If `alpha == 0` or `centered` is `True`.
        """
        # Reject configurations this implementation does not handle before
        # any state is set up in the base class.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")
        super().__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            8,  # optimizer-state bit width is fixed for this class
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
class RMSprop32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """
        32-bit RMSprop optimizer.

        Same as `RMSprop` with the optimizer-state bit width fixed to 32. There is
        no `optim_bits` parameter.

        Arguments:
            params (`torch.tensor`):
                The input parameters to optimize; forwarded unchanged to `Optimizer1State`.
            lr (`float`, defaults to 1e-2):
                The learning rate.
            alpha (`float`, defaults to 0.99):
                The decay rate of the running average of squared gradients. Must be non-zero.
            eps (`float`, defaults to 1e-8):
                The epsilon value that prevents division by zero in the optimizer.
            weight_decay (`float`, defaults to 0):
                The weight decay value for the optimizer.
            momentum (`float`, defaults to 0):
                The momentum factor. It is passed to the base class together with `alpha`.
            centered (`bool`, defaults to `False`):
                Centered RMSprop (normalizing by the estimated gradient variance) is not
                supported. This must be `False`; passing `True` raises `NotImplementedError`.
            args (`object`, defaults to `None`):
                An object with additional arguments.
            min_8bit_size (`int`, defaults to 4096):
                The minimum number of elements of the parameter tensors for 8-bit optimization.
                Accepted for signature parity with the other RMSprop classes and forwarded
                to the base class.
            percentile_clipping (`int`, defaults to 100):
                Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
            block_wise (`bool`, defaults to `True`):
                Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.

        Raises:
            NotImplementedError: If `alpha == 0` or `centered` is `True`.
        """
        # Reject configurations this implementation does not handle before
        # any state is set up in the base class.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")
        super().__init__(
            "rmsprop",
            params,
            lr,
            (alpha, momentum),
            eps,
            weight_decay,
            32,  # optimizer-state bit width is fixed for this class
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
|