apoorvrajdev's picture
feat(evaluation): add beam search, metrics pipeline, and stabilized training workflow
91a1214
"""Learning-rate schedules.
The baseline pipeline uses a constant Adam LR (matching the IEEE notebook),
which is fine for short fine-tuning runs but tends to leave Transformer
captioners in a mediocre local minimum: the LR is too aggressive at start
(decoder weights are still random) and too high near convergence (the model
oscillates around a flat basin instead of settling).
The fix the literature converged on is linear warmup followed by cosine
decay (the GPT/BERT/ViT recipe):
    lr(step) = peak_lr * step / warmup_steps              if step < warmup_steps
    lr(step) = min_lr + (peak_lr - min_lr) * 0.5
               * (1 + cos(pi * progress))                 otherwise
    where progress = min(1, (step - warmup_steps) / decay_steps)
We implement it as a ``LearningRateSchedule`` so the optimizer can call it
per-step automatically, without us having to track step counts manually.
"""
from __future__ import annotations
def _build_warmup_cosine_class():
    """Build and return the ``WarmupCosineDecay`` schedule class.

    TensorFlow is imported inside the function body rather than at module
    top level, so it is only loaded when this builder actually runs.

    Returns:
        The ``WarmupCosineDecay`` class, a subclass of
        ``tf.keras.optimizers.schedules.LearningRateSchedule``.
    """
    import math

    import tensorflow as tf

    class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warmup followed by cosine decay to ``min_learning_rate``.

        Args:
            peak_learning_rate: Maximum LR reached at the end of warmup.
            warmup_steps: Number of steps to linearly ramp from 0 to peak.
                ``0`` disables warmup (starts directly at ``peak``).
                Negative values are treated as ``0``.
            decay_steps: Number of steps over which to cosine-decay from
                ``peak`` to ``min_learning_rate`` after warmup. Values
                below ``1`` are raised to ``1``.
            min_learning_rate: Floor reached at the end of decay (and held
                thereafter).
        """

        def __init__(
            self,
            peak_learning_rate: float,
            warmup_steps: int,
            decay_steps: int,
            min_learning_rate: float = 0.0,
        ) -> None:
            super().__init__()
            self.peak_learning_rate = float(peak_learning_rate)
            # Clamp the same way decay_steps is clamped: a negative warmup
            # would make ``step - warmup`` start positive, so the cosine
            # would begin partway through and never emit the peak LR.
            self.warmup_steps = max(int(warmup_steps), 0)
            self.decay_steps = max(int(decay_steps), 1)
            self.min_learning_rate = float(min_learning_rate)

        def __call__(self, step):
            """Return the learning rate for ``step`` as a float32 tensor."""
            step = tf.cast(step, tf.float32)
            peak = tf.constant(self.peak_learning_rate, dtype=tf.float32)
            floor = tf.constant(self.min_learning_rate, dtype=tf.float32)
            warmup = tf.constant(float(self.warmup_steps), dtype=tf.float32)
            decay = tf.constant(float(self.decay_steps), dtype=tf.float32)

            # During warmup: linear ramp 0 -> peak. divide_no_nan yields 0
            # when warmup == 0; that value is never selected below because
            # ``step < 0`` is false for every non-negative step.
            warmup_lr = peak * tf.math.divide_no_nan(step, warmup)

            # After warmup: cosine decay peak -> floor over decay_steps.
            # progress is capped at 1 so the LR holds at the floor once the
            # decay window has elapsed. It is negative during warmup, but
            # that branch is masked out by tf.where.
            progress = tf.minimum(1.0, tf.math.divide_no_nan(step - warmup, decay))
            cosine = 0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * progress))
            decay_lr = floor + (peak - floor) * cosine

            return tf.where(step < warmup, warmup_lr, decay_lr)

        def get_config(self) -> dict[str, float | int]:
            """Return the constructor arguments (after clamping) as a dict."""
            return {
                "peak_learning_rate": self.peak_learning_rate,
                "warmup_steps": self.warmup_steps,
                "decay_steps": self.decay_steps,
                "min_learning_rate": self.min_learning_rate,
            }

    return WarmupCosineDecay
def _warmup_cosine_class():
    """Return the ``WarmupCosineDecay`` class, building it on first use."""
    cls = globals().get("WarmupCosineDecay")
    if cls is None:
        cls = _build_warmup_cosine_class()
        # Cache under the public name so later module-attribute lookups
        # find it in the module dict and never reach ``__getattr__`` again.
        globals()["WarmupCosineDecay"] = cls
    return cls


def __getattr__(name: str):
    """Resolve ``WarmupCosineDecay`` lazily on attribute access (PEP 562).

    Building the class imports TensorFlow, so it is deferred until the name
    is first requested (``module.WarmupCosineDecay`` or
    ``from module import WarmupCosineDecay``) instead of happening when this
    module is imported.

    Raises:
        AttributeError: For any other missing attribute, matching normal
            module attribute semantics.
    """
    if name == "WarmupCosineDecay":
        return _warmup_cosine_class()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def build_learning_rate(
    *,
    schedule: str,
    peak_learning_rate: float,
    warmup_steps: int,
    decay_steps: int,
    min_learning_rate: float,
):
    """Return either a float (constant LR) or a :class:`WarmupCosineDecay`.

    The optimizer treats a float as a fixed LR and a ``LearningRateSchedule``
    as a per-step callable — we hide that asymmetry behind this factory so
    the trainer only ever passes ``learning_rate=build_learning_rate(...)``.

    Args:
        schedule: ``"constant"`` or ``"cosine"``.
        peak_learning_rate: The fixed LR for ``"constant"``; the post-warmup
            maximum for ``"cosine"``.
        warmup_steps: Warmup length (used by ``"cosine"`` only).
        decay_steps: Cosine-decay length (used by ``"cosine"`` only).
        min_learning_rate: Decay floor (used by ``"cosine"`` only).

    Returns:
        ``peak_learning_rate`` unchanged for ``"constant"``, otherwise a
        ``WarmupCosineDecay`` instance.

    Raises:
        ValueError: If ``schedule`` is not a supported name.
    """
    if schedule == "constant":
        # No schedule object needed, so this path never builds the class.
        return peak_learning_rate
    if schedule == "cosine":
        schedule_cls = _warmup_cosine_class()
        return schedule_cls(
            peak_learning_rate=peak_learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
            min_learning_rate=min_learning_rate,
        )
    raise ValueError(f"unsupported lr_schedule: {schedule!r}")