apoorvrajdev's picture
feat(evaluation): add beam search, metrics pipeline, and stabilized training workflow
91a1214
"""``Trainer`` — orchestration around ``model.compile + model.fit``.
Wraps notebook cells 22 and 23 in a class so:
* Tests can construct a Trainer with a tiny dataset and assert
``trainer.fit`` returns a sensible history dict.
* Phase 4 can replace the trainer with a CLI-driven main loop without
changing the notebook-equivalent behaviour.
The trainer reads the optional training-stability fields off ``TrainConfig``
(``label_smoothing``, ``lr_schedule``, ``warmup_steps``, ...). With defaults
in place every existing config produces a byte-identical compile call to the
notebook; flipping one YAML flag opts a run into the modern recipe without
touching code.
"""
from __future__ import annotations
import json
from pathlib import Path
from captioning.config.schema import AppConfig
from captioning.training.callbacks import default_callbacks
from captioning.training.losses import build_loss
from captioning.training.schedules import build_learning_rate
from captioning.utils.logging import get_logger
log = get_logger(__name__)
def _infer_steps_per_epoch(dataset) -> int | None:
    """Probe a ``tf.data.Dataset`` for its number of batches, best-effort.

    The result is only used to derive ``cosine_decay_steps`` when the config
    does not pin that value explicitly.

    Args:
        dataset: Object handed to ``tf.data.experimental.cardinality``.

    Returns:
        The cardinality as an ``int``, or ``None`` when it is unknown,
        infinite, or the probe itself fails for any reason.
    """
    try:
        # Imported lazily so the module can be loaded without TensorFlow.
        import tensorflow as tf

        batches = int(tf.data.experimental.cardinality(dataset).numpy())
    except Exception:  # deliberate: any probing failure just means "unknown"
        return None
    # tf.data sentinels: -1 is INFINITE_CARDINALITY, -2 is UNKNOWN_CARDINALITY.
    return None if batches in (-1, -2) else batches
class Trainer:
    """Thin orchestration layer around an ``ImageCaptioningModel``.

    Builds the loss, learning-rate schedule and Adam optimizer from
    ``config.train``, compiles the wrapped model, and runs ``model.fit`` with
    the project's default callbacks.
    """

    def __init__(self, model, config: AppConfig) -> None:
        """Store the model and config; compilation is deferred.

        Args:
            model: Result of ``build_caption_model(config, vocab_size)``.
            config: Validated ``AppConfig``.
        """
        self.model = model
        self.config = config
        # Flipped by ``compile``; lets ``fit`` compile lazily exactly once.
        self._compiled = False

    def compile(self, *, steps_per_epoch: int | None = None) -> None:
        """Build optimizer + loss from config and call ``model.compile``.

        Args:
            steps_per_epoch: Used to derive ``cosine_decay_steps`` when the
                config doesn't pin it explicitly. Passing ``None`` falls back
                to the config value (or 1 if neither is set — degenerates to
                immediate floor LR, but still a well-defined schedule).
        """
        # Imported lazily so constructing a Trainer does not require TensorFlow.
        import tensorflow as tf

        train = self.config.train
        # Vocab size lives on the decoder's final Dense layer; pulling it here
        # avoids threading the tokenizer through the trainer just for loss.
        vocab_size = int(self.model.decoder.out.units)
        loss = build_loss(train.label_smoothing, vocab_size)
        # An explicit config value wins; otherwise span the whole run
        # (steps per epoch x epochs), with both factors clamped to >= 1.
        cosine_steps = train.cosine_decay_steps or (
            (steps_per_epoch or 1) * max(train.epochs, 1)
        )
        learning_rate = build_learning_rate(
            schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            decay_steps=cosine_steps,
            min_learning_rate=train.min_learning_rate,
        )
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=loss,
        )
        self._compiled = True
        log.info(
            "model_compiled",
            lr_schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            cosine_decay_steps=cosine_steps,
            label_smoothing=train.label_smoothing,
        )

    @staticmethod
    def _to_builtin_history(raw) -> dict[str, list[float]]:
        """Copy a history mapping, coercing each entry to a built-in ``float``.

        Entries that ``float()`` rejects (e.g. ``None``) are kept unchanged so
        this helper never raises on its own.
        """

        def _scalar(value):
            try:
                return float(value)
            except (TypeError, ValueError):
                return value

        return {name: [_scalar(v) for v in values] for name, values in raw.items()}

    def fit(
        self,
        train_dataset,
        val_dataset,
        *,
        output_dir: str | Path | None = None,
    ) -> dict[str, list[float]]:
        """Run ``model.fit`` and return a history dict.

        Compiles the model first if ``compile`` has not been called yet,
        inferring ``steps_per_epoch`` from ``train_dataset`` when possible.

        Args:
            train_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_train_pipeline``.
            val_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_val_pipeline``.
            output_dir: If provided, callbacks write ``best.h5`` and
                ``training_log.csv`` here, and ``history.json`` is dumped at
                the end. The directory is created if it does not exist.

        Returns:
            ``history.history`` as a ``dict[str, list[float]]`` with values
            coerced to built-in floats.
        """
        if not self._compiled:
            self.compile(steps_per_epoch=_infer_steps_per_epoch(train_dataset))

        out_path = None
        if output_dir is not None:
            out_path = Path(output_dir)
            # Create it up front so neither the callbacks nor the final
            # history dump can fail on a missing directory.
            out_path.mkdir(parents=True, exist_ok=True)

        callbacks = default_callbacks(self.config, output_dir=output_dir)
        log.info("fit_start", epochs=self.config.train.epochs)
        history = self.model.fit(
            train_dataset,
            epochs=self.config.train.epochs,
            validation_data=val_dataset,
            callbacks=callbacks,
        )
        # Non-builtin scalars (e.g. NumPy float32) are not JSON-serializable;
        # normalise once and use the same data for logging, disk and return.
        records = self._to_builtin_history(history.history)
        # ``or`` also covers a present-but-empty loss list.
        losses = records.get("loss") or [None]
        log.info("fit_end", final_loss=losses[-1])
        if out_path is not None:
            history_path = out_path / "history.json"
            with history_path.open("w", encoding="utf-8") as f:
                json.dump(records, f, indent=2)
        return records