File size: 8,799 Bytes
1f5470c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 | from keras.src import initializers
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.optimizers import optimizer
@keras_export(["keras.optimizers.Ftrl"])
class Ftrl(optimizer.Optimizer):
    r"""Optimizer that implements the FTRL algorithm.

    "Follow The Regularized Leader" (FTRL) is an optimization algorithm
    developed at Google for click-through rate prediction in the early 2010s. It
    is most suitable for shallow models with large and sparse feature spaces.
    The algorithm is described by
    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
    The Keras version has support for both online L2 regularization
    (the L2 regularization described in the paper
    above) and shrinkage-type L2 regularization
    (which is the addition of an L2 penalty to the loss function).

    Initialization:

    ```python
    n = initial_accumulator_value
    sigma = 0
    z = 0
    ```

    Update rule for one variable `w`:

    ```python
    prev_n = n
    n = n + g ** 2
    sigma = (n ** -lr_power - prev_n ** -lr_power) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
        w = 0
    else:
        w = (sgn(z) * lambda_1 - z) / (
            (beta + n ** -lr_power) / lr + 2 * lambda_2
        )
    ```

    Notation:

    - `lr` is the learning rate
    - `g` is the gradient for the variable
    - `n` is the per-variable accumulator of squared gradients
    - `z` is the per-variable linear term
    - `lambda_1` is the L1 regularization strength
    - `lambda_2` is the L2 regularization strength
    - `lr_power` is the power to scale n. With the default `-0.5`,
      `n ** -lr_power` is `sqrt(n)`.
    - `beta` is the `beta` argument

    Check the documentation for the `l2_shrinkage_regularization_strength`
    parameter for more details when shrinkage is enabled, in which case gradient
    is replaced with a gradient with shrinkage
    (`g + 2 * l2_shrinkage_regularization_strength * w`) in the update of `z`;
    the accumulator `n` still uses the plain gradient.

    Args:
        learning_rate: A float, a
            `keras.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to `0.001`.
        learning_rate_power: A float value, must be less or equal to zero.
            Controls how the learning rate decreases during training. Use zero
            for a fixed learning rate. Defaults to `-0.5`.
        initial_accumulator_value: The starting value for accumulators. Only
            zero or positive values are allowed. Defaults to `0.1`.
        l1_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_regularization_strength: A float value, must be greater than or equal
            to zero. Defaults to `0.0`.
        l2_shrinkage_regularization_strength: A float value, must be greater
            than or equal to zero. This differs from L2 above in that the L2
            above is a stabilization penalty, whereas this L2 shrinkage is a
            magnitude penalty. When input is sparse shrinkage will only happen
            on the active weights. Defaults to `0.0`.
        beta: A float value, representing the beta value from the paper.
            Defaults to `0.0`.
        {{base_optimizer_keyword_args}}

    Raises:
        ValueError: If `initial_accumulator_value`,
            `l1_regularization_strength`, `l2_regularization_strength` or
            `l2_shrinkage_regularization_strength` is negative, or if
            `learning_rate_power` is positive.
    """

    def __init__(
        self,
        learning_rate=0.001,
        learning_rate_power=-0.5,
        initial_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0,
        l2_shrinkage_regularization_strength=0.0,
        beta=0.0,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        loss_scale_factor=None,
        gradient_accumulation_steps=None,
        name="ftrl",
        **kwargs,
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            loss_scale_factor=loss_scale_factor,
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )

        # Reject out-of-range FTRL hyperparameters before storing them.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` needs to be positive or zero. "
                "Received: initial_accumulator_value="
                f"{initial_accumulator_value}."
            )
        if learning_rate_power > 0.0:
            raise ValueError(
                "`learning_rate_power` needs to be negative or zero. Received: "
                f"learning_rate_power={learning_rate_power}."
            )
        if l1_regularization_strength < 0.0:
            raise ValueError(
                "`l1_regularization_strength` needs to be positive or zero. "
                "Received: l1_regularization_strength="
                f"{l1_regularization_strength}."
            )
        if l2_regularization_strength < 0.0:
            raise ValueError(
                "`l2_regularization_strength` needs to be positive or zero. "
                "Received: l2_regularization_strength="
                f"{l2_regularization_strength}."
            )
        if l2_shrinkage_regularization_strength < 0.0:
            raise ValueError(
                "`l2_shrinkage_regularization_strength` needs to be positive "
                "or zero. Received: l2_shrinkage_regularization_strength"
                f"={l2_shrinkage_regularization_strength}."
            )

        self.learning_rate_power = learning_rate_power
        self.initial_accumulator_value = initial_accumulator_value
        self.l1_regularization_strength = l1_regularization_strength
        self.l2_regularization_strength = l2_regularization_strength
        self.l2_shrinkage_regularization_strength = (
            l2_shrinkage_regularization_strength
        )
        self.beta = beta

    def build(self, var_list):
        """Initialize optimizer variables.

        Creates an "accumulator" slot (filled with
        `initial_accumulator_value`) and a "linear" slot (filled with zeros)
        for the variables in `var_list`. Does nothing if already built.

        Args:
            var_list: list of model variables to build Ftrl variables on.
        """
        if self.built:
            return
        super().build(var_list)
        accumulator_initializer = initializers.Constant(
            self.initial_accumulator_value,
        )
        self._accumulators, self._linears = self.add_optimizer_variables(
            var_list,
            ["accumulator", "linear"],
            initializer=[accumulator_initializer, "zeros"],
        )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable.

        Updates the "linear" slot, `variable`, and the "accumulator" slot,
        in that order.

        Args:
            gradient: gradient for `variable`; cast to `variable.dtype`.
            variable: the model variable to update.
            learning_rate: learning rate for this step; cast to
                `variable.dtype`.
        """
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        # Both slot lists are indexed by the same variable index, so look it
        # up once.
        index = self._get_variable_index(variable)
        accum = self._accumulators[index]
        linear = self._linears[index]

        lr_power = self.learning_rate_power
        l1_reg = self.l1_regularization_strength
        # Fold `beta` into the L2 term so that
        # 2 * l2_reg == 2 * l2_regularization_strength + beta / lr.
        l2_reg = self.l2_regularization_strength + self.beta / (2.0 * lr)

        # The shrinkage term only enters the linear update below; the
        # accumulator is built from the plain gradient.
        grad_to_use = ops.add(
            gradient,
            ops.multiply(
                2 * self.l2_shrinkage_regularization_strength, variable
            ),
        )
        new_accum = ops.add(accum, ops.square(gradient))

        # `new_accum ** -lr_power` feeds both sigma and the denominator, so
        # compute it once.
        accum_pow = ops.power(accum, -lr_power)
        new_accum_pow = ops.power(new_accum, -lr_power)

        # linear += grad_to_use - sigma * variable
        sigma = ops.divide(ops.subtract(new_accum_pow, accum_pow), lr)
        self.assign_add(
            linear,
            ops.subtract(grad_to_use, ops.multiply(sigma, variable)),
        )

        quadratic = ops.add(ops.divide(new_accum_pow, lr), 2 * l2_reg)
        # clip(linear) - linear is zero wherever |linear| <= l1_reg, which
        # sets the corresponding weights to zero.
        linear_clipped = ops.clip(linear, -l1_reg, l1_reg)
        self.assign(
            variable,
            ops.divide(ops.subtract(linear_clipped, linear), quadratic),
        )
        self.assign(accum, new_accum)

    def get_config(self):
        """Return the base optimizer config extended with FTRL settings."""
        config = super().get_config()
        config.update(
            {
                "learning_rate_power": self.learning_rate_power,
                "initial_accumulator_value": self.initial_accumulator_value,
                "l1_regularization_strength": self.l1_regularization_strength,
                "l2_regularization_strength": self.l2_regularization_strength,
                "l2_shrinkage_regularization_strength": (
                    self.l2_shrinkage_regularization_strength
                ),
                "beta": self.beta,
            }
        )
        return config
# Replace the `{{base_optimizer_keyword_args}}` placeholder in the class
# docstring with `optimizer.base_optimizer_keyword_args`. `__doc__` is `None`
# when docstrings are stripped (`python -OO`), so skip the substitution then
# instead of raising `AttributeError` at import time.
if Ftrl.__doc__ is not None:
    Ftrl.__doc__ = Ftrl.__doc__.replace(
        "{{base_optimizer_keyword_args}}",
        optimizer.base_optimizer_keyword_args,
    )
|