"""``Trainer`` — orchestration around ``model.compile + model.fit``.

Wraps notebook cells 22 and 23 in a class so:

* Tests can construct a Trainer with a tiny dataset and assert
  ``trainer.fit`` returns a sensible history dict.
* Phase 4 can replace the trainer with a CLI-driven main loop without
  changing the notebook-equivalent behaviour.

The trainer reads the optional training-stability fields off ``TrainConfig``
(``label_smoothing``, ``lr_schedule``, ``warmup_steps``, ...). With defaults
in place every existing config produces a byte-identical compile call to the
notebook; flipping one YAML flag opts a run into the modern recipe without
touching code.
"""
from __future__ import annotations
import json
from pathlib import Path
from captioning.config.schema import AppConfig
from captioning.training.callbacks import default_callbacks
from captioning.training.losses import build_loss
from captioning.training.schedules import build_learning_rate
from captioning.utils.logging import get_logger
# Module-level logger. Calls in this module pass an event name followed by
# key=value fields, e.g. log.info("fit_start", epochs=...).
log = get_logger(__name__)
def _infer_steps_per_epoch(dataset) -> int | None:
"""Best-effort cardinality probe for a ``tf.data.Dataset``.
Returns ``None`` when the dataset's cardinality is unknown or infinite.
Used only to derive ``cosine_decay_steps`` when the user didn't pin it
explicitly.
"""
try:
import tensorflow as tf
card = int(tf.data.experimental.cardinality(dataset).numpy())
except Exception: # — cardinality probing is best-effort
return None
if card in (-1, -2): # UNKNOWN, INFINITE
return None
return card
class Trainer:
    """Thin orchestration layer around an ``ImageCaptioningModel``.

    ``compile`` derives the optimizer, loss and learning-rate schedule from
    ``config.train``. ``fit`` wires up callbacks, runs training and
    optionally persists the history. ``fit`` compiles lazily when
    ``compile`` was not called first.
    """

    def __init__(self, model, config: AppConfig) -> None:
        """Store the model and config; compilation is deferred.

        Args:
            model: Result of ``build_caption_model(config, vocab_size)``.
            config: Validated ``AppConfig``.
        """
        self.model = model
        self.config = config
        # Set by ``compile``; lets ``fit`` compile on demand exactly once.
        self._compiled = False

    def compile(self, *, steps_per_epoch: int | None = None) -> None:
        """Build optimizer + loss from config and call ``model.compile``.

        Args:
            steps_per_epoch: Used to derive the cosine decay horizon when
                ``train.cosine_decay_steps`` is unset (or 0). In that case the
                horizon is ``steps_per_epoch * max(train.epochs, 1)``.
                ``None`` (or 0) is treated as one step per epoch, which
                shrinks the horizon to ``max(train.epochs, 1)`` steps. The
                result is a degenerate but still well-defined schedule.
        """
        import tensorflow as tf

        train = self.config.train
        # Vocab size lives on the decoder's final Dense layer; pulling it here
        # avoids threading the tokenizer through the trainer just for loss.
        vocab_size = int(self.model.decoder.out.units)
        loss = build_loss(train.label_smoothing, vocab_size)

        cosine_steps = train.cosine_decay_steps or (
            (steps_per_epoch or 1) * max(train.epochs, 1)
        )
        learning_rate = build_learning_rate(
            schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            decay_steps=cosine_steps,
            min_learning_rate=train.min_learning_rate,
        )
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=loss,
        )
        self._compiled = True
        log.info(
            "model_compiled",
            lr_schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            cosine_decay_steps=cosine_steps,
            label_smoothing=train.label_smoothing,
        )

    def fit(
        self,
        train_dataset,
        val_dataset,
        *,
        output_dir: str | Path | None = None,
    ) -> dict[str, list[float]]:
        """Run ``model.fit`` and return a history dict.

        Args:
            train_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_train_pipeline``.
            val_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_val_pipeline``.
            output_dir: If provided, it is created if missing. Callbacks
                write ``best.h5`` and ``training_log.csv`` there, and
                ``history.json`` is dumped there at the end.

        Returns:
            A copy of ``history.history`` with every value converted to a
            Python ``float``.
        """
        if not self._compiled:
            self.compile(steps_per_epoch=_infer_steps_per_epoch(train_dataset))

        if output_dir is not None:
            # Callbacks and the final history dump both write into this
            # directory; make sure it exists before any of them run.
            Path(output_dir).mkdir(parents=True, exist_ok=True)

        callbacks = default_callbacks(self.config, output_dir=output_dir)
        log.info("fit_start", epochs=self.config.train.epochs)
        history = self.model.fit(
            train_dataset,
            epochs=self.config.train.epochs,
            validation_data=val_dataset,
            callbacks=callbacks,
        )

        metrics = self._history_as_floats(history.history)
        # ``or`` also covers a present-but-empty "loss" list, which would
        # otherwise raise IndexError on ``[-1]``.
        log.info("fit_end", final_loss=(metrics.get("loss") or [None])[-1])

        if output_dir is not None:
            history_path = Path(output_dir) / "history.json"
            with history_path.open("w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
        return metrics

    @staticmethod
    def _history_as_floats(raw) -> dict[str, list[float]]:
        """Return a copy of a Keras history mapping with ``float`` values.

        Keras callbacks (e.g. ``ReduceLROnPlateau`` logging ``lr``) may
        record NumPy scalars. ``json.dump`` rejects those, and they do not
        match ``fit``'s declared return type.
        """
        return {key: [float(value) for value in values] for key, values in raw.items()}
|