Spaces:
Configuration error
Configuration error
| """Learning-rate schedules. | |
| The baseline pipeline uses a constant Adam LR (matching the IEEE notebook), | |
| which is fine for short fine-tuning runs but tends to leave Transformer | |
| captioners in a mediocre local minimum: the LR is too aggressive at start | |
| (decoder weights are still random) and too high near convergence (the model | |
| oscillates around a flat basin instead of settling). | |
| The fix the literature converged on is linear warmup followed by cosine | |
| decay (the GPT/BERT/ViT recipe): | |
| lr(step) = peak_lr * step / warmup_steps if step < warmup_steps | |
| lr(step) = min_lr + (peak_lr - min_lr) * 0.5 * | |
| (1 + cos(pi * min(1, (step - warmup_steps) / decay_steps))) otherwise | |
| We implement it as a ``LearningRateSchedule`` so the optimizer can call it | |
| per-step automatically, without us having to track step counts manually. | |
| """ | |
| from __future__ import annotations | |
def _build_warmup_cosine_class():
    """Create and return the ``WarmupCosineDecay`` schedule class.

    TensorFlow is imported inside this factory instead of at the top of the
    module, so the import happens only when the factory is called.
    """
    import tensorflow as tf

    class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Learning rate that ramps up linearly, then anneals along a half cosine.

        Args:
            peak_learning_rate: Highest LR, reached when warmup finishes.
            warmup_steps: Length of the linear ramp from 0 up to the peak.
                With ``0`` there is no ramp and the schedule begins at the
                peak value.
            decay_steps: Length of the cosine anneal from the peak down to
                ``min_learning_rate``, counted from the end of warmup.
                Values below 1 are raised to 1.
            min_learning_rate: LR reached once the anneal completes; it is
                kept at this value for all later steps.
        """

        def __init__(
            self,
            peak_learning_rate: float,
            warmup_steps: int,
            decay_steps: int,
            min_learning_rate: float = 0.0,
        ) -> None:
            super().__init__()
            self.peak_learning_rate = float(peak_learning_rate)
            self.warmup_steps = int(warmup_steps)
            # Keep the anneal length at one step or more.
            requested_decay = int(decay_steps)
            self.decay_steps = requested_decay if requested_decay > 1 else 1
            self.min_learning_rate = float(min_learning_rate)

        def __call__(self, step):
            """Return the learning rate for ``step`` as a float32 tensor."""
            current = tf.cast(step, tf.float32)
            lr_peak = tf.constant(self.peak_learning_rate, dtype=tf.float32)
            lr_floor = tf.constant(self.min_learning_rate, dtype=tf.float32)
            n_warmup = tf.constant(float(self.warmup_steps), dtype=tf.float32)
            n_decay = tf.constant(float(self.decay_steps), dtype=tf.float32)

            # Warmup branch: fraction of the ramp completed, scaled to peak.
            # divide_no_nan yields 0 when warmup_steps is 0.
            ramp_fraction = tf.math.divide_no_nan(current, n_warmup)
            ramp_lr = lr_peak * ramp_fraction

            # Anneal branch: fraction of the decay completed, capped at 1 so
            # the LR stays at the floor once decay_steps have elapsed.
            elapsed = current - n_warmup
            fraction_done = tf.minimum(
                1.0, tf.math.divide_no_nan(elapsed, n_decay)
            )
            angle = tf.constant(3.141592653589793) * fraction_done
            half_cosine = 0.5 * (1.0 + tf.cos(angle))
            annealed_lr = lr_floor + (lr_peak - lr_floor) * half_cosine

            # Both branches are computed; pick the one matching the phase.
            in_warmup = current < n_warmup
            return tf.where(in_warmup, ramp_lr, annealed_lr)

        def get_config(self) -> dict[str, float | int]:
            """Return the constructor arguments as a plain dict."""
            config_keys = (
                "peak_learning_rate",
                "warmup_steps",
                "decay_steps",
                "min_learning_rate",
            )
            return {key: getattr(self, key) for key in config_keys}

    return WarmupCosineDecay
# Cache for the lazily built schedule class. It stays ``None`` until the class
# is first requested, so importing this module does not import TensorFlow.
_WARMUP_COSINE_CLASS = None


def _get_warmup_cosine_class():
    """Return the ``WarmupCosineDecay`` class, building it on first use."""
    global _WARMUP_COSINE_CLASS
    if _WARMUP_COSINE_CLASS is None:
        _WARMUP_COSINE_CLASS = _build_warmup_cosine_class()
    return _WARMUP_COSINE_CLASS


def __getattr__(name: str):
    """Resolve ``WarmupCosineDecay`` on first attribute access (PEP 562).

    Keeps ``module.WarmupCosineDecay`` and ``from module import
    WarmupCosineDecay`` working without building the class at import time.

    Raises:
        AttributeError: For any other missing module attribute.
    """
    if name == "WarmupCosineDecay":
        return _get_warmup_cosine_class()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def build_learning_rate(
    *,
    schedule: str,
    peak_learning_rate: float,
    warmup_steps: int,
    decay_steps: int,
    min_learning_rate: float,
):
    """Return either a float (constant LR) or a ``WarmupCosineDecay``.

    The optimizer treats a float as a fixed LR and a ``LearningRateSchedule``
    as a per-step callable; this factory hides that asymmetry so the trainer
    only ever passes ``learning_rate=build_learning_rate(...)``.

    Args:
        schedule: ``"constant"`` or ``"cosine"``.
        peak_learning_rate: The fixed LR for ``"constant"``; the peak of the
            warmup/cosine schedule for ``"cosine"``.
        warmup_steps: Linear warmup length (used by ``"cosine"`` only).
        decay_steps: Cosine decay length (used by ``"cosine"`` only).
        min_learning_rate: Final LR floor (used by ``"cosine"`` only).

    Returns:
        ``peak_learning_rate`` unchanged for ``"constant"``, otherwise a
        ``WarmupCosineDecay`` instance.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == "constant":
        return peak_learning_rate
    if schedule == "cosine":
        # The class is fetched only on this branch, so the constant-LR path
        # never triggers the TensorFlow import.
        schedule_cls = _get_warmup_cosine_class()
        return schedule_cls(
            peak_learning_rate=peak_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
            min_learning_rate=min_learning_rate,
        )
    raise ValueError(f"unsupported lr_schedule: {schedule!r}")