File size: 5,179 Bytes
3a2e5f0
 
 
 
 
 
 
 
91a1214
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
91a1214
 
3a2e5f0
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
3a2e5f0
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
91a1214
 
3a2e5f0
 
91a1214
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""``Trainer`` — orchestration around ``model.compile + model.fit``.

Wraps notebook cells 22 and 23 in a class so:
    * Tests can construct a Trainer with a tiny dataset and assert
      ``trainer.fit`` returns a sensible history dict.
    * Phase 4 can replace the trainer with a CLI-driven main loop without
      changing the notebook-equivalent behaviour.

The trainer reads the optional training-stability fields off ``TrainConfig``
(``label_smoothing``, ``lr_schedule``, ``warmup_steps``, ...). With defaults
in place every existing config produces a byte-identical compile call to the
notebook; flipping one YAML flag opts a run into the modern recipe without
touching code.
"""

from __future__ import annotations

import json
from pathlib import Path

from captioning.config.schema import AppConfig
from captioning.training.callbacks import default_callbacks
from captioning.training.losses import build_loss
from captioning.training.schedules import build_learning_rate
from captioning.utils.logging import get_logger

log = get_logger(__name__)


def _infer_steps_per_epoch(dataset) -> int | None:
    """Best-effort cardinality probe for a ``tf.data.Dataset``.

    Returns ``None`` when the dataset's cardinality is unknown or infinite.
    Used only to derive ``cosine_decay_steps`` when the user didn't pin it
    explicitly.
    """
    try:
        import tensorflow as tf

        card = int(tf.data.experimental.cardinality(dataset).numpy())
    except Exception:  # — cardinality probing is best-effort
        return None
    if card in (-1, -2):  # UNKNOWN, INFINITE
        return None
    return card


class Trainer:
    """Thin orchestration layer around an ``ImageCaptioningModel``."""

    def __init__(self, model, config: AppConfig) -> None:
        """Store the model and config; compilation is deferred until needed.

        Args:
            model: Result of ``build_caption_model(config, vocab_size)``.
            config: Validated ``AppConfig``.
        """
        self.model = model
        self.config = config
        # Flipped by ``compile``; ``fit`` compiles lazily while this is False.
        self._compiled = False

    @staticmethod
    def _history_as_floats(history) -> dict[str, list[float]]:
        """Copy a ``History.history`` mapping with values coerced to ``float``.

        Keras callbacks may log NumPy scalars (e.g. ``np.float32``), which
        ``json.dump`` refuses to serialize. Each value is passed through
        ``float()``. Values that cannot be coerced are kept unchanged, so a
        finished training run is never lost to a conversion error.
        """

        def _coerce(value):
            try:
                return float(value)
            except (TypeError, ValueError):
                return value

        return {key: [_coerce(v) for v in values] for key, values in history.items()}

    def compile(self, *, steps_per_epoch: int | None = None) -> None:
        """Build optimizer + loss from config and call ``model.compile``.

        Args:
            steps_per_epoch: Used to derive ``cosine_decay_steps`` when the
                config doesn't pin it (i.e. ``cosine_decay_steps`` is falsy).
                The derived value is ``steps_per_epoch * max(epochs, 1)``.
                A ``None`` (or 0) ``steps_per_epoch`` is treated as 1, so the
                decay then spans only ``max(epochs, 1)`` steps. That is a very
                short schedule, but still a well-defined one.
        """
        import tensorflow as tf

        train = self.config.train
        # Vocab size lives on the decoder's final Dense layer; pulling it here
        # avoids threading the tokenizer through the trainer just for loss.
        vocab_size = int(self.model.decoder.out.units)
        loss = build_loss(train.label_smoothing, vocab_size)

        # An explicit config value wins; otherwise decay over the whole run.
        cosine_steps = train.cosine_decay_steps or (
            (steps_per_epoch or 1) * max(train.epochs, 1)
        )
        learning_rate = build_learning_rate(
            schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            decay_steps=cosine_steps,
            min_learning_rate=train.min_learning_rate,
        )

        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=loss,
        )
        self._compiled = True
        log.info(
            "model_compiled",
            lr_schedule=train.lr_schedule,
            peak_learning_rate=train.learning_rate,
            warmup_steps=train.warmup_steps,
            cosine_decay_steps=cosine_steps,
            label_smoothing=train.label_smoothing,
        )

    def fit(
        self,
        train_dataset,
        val_dataset,
        *,
        output_dir: str | Path | None = None,
    ) -> dict[str, list[float]]:
        """Run ``model.fit`` and return a history dict.

        Compiles the model first (via ``compile``) if that hasn't happened yet.

        Args:
            train_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_train_pipeline``.
            val_dataset: ``tf.data.Dataset`` from
                ``data.pipeline.build_val_pipeline``.
            output_dir: If provided, it is passed to ``default_callbacks``
                (whose outputs, per the original docs, include ``best.h5``
                and ``training_log.csv``). ``history.json`` is also dumped
                there at the end, creating the directory if missing.

        Returns:
            A copy of ``history.history`` with values coerced to ``float``.
        """
        if not self._compiled:
            # Size the cosine decay from the dataset when its cardinality is known.
            self.compile(steps_per_epoch=_infer_steps_per_epoch(train_dataset))

        callbacks = default_callbacks(self.config, output_dir=output_dir)
        log.info("fit_start", epochs=self.config.train.epochs)
        history = self.model.fit(
            train_dataset,
            epochs=self.config.train.epochs,
            validation_data=val_dataset,
            callbacks=callbacks,
        )
        history_dict = self._history_as_floats(history.history)
        # ``or [None]`` also covers a "loss" key that is present but empty.
        final_loss = (history_dict.get("loss") or [None])[-1]
        log.info("fit_end", final_loss=final_loss)

        if output_dir is not None:
            history_path = Path(output_dir) / "history.json"
            history_path.parent.mkdir(parents=True, exist_ok=True)
            with history_path.open("w", encoding="utf-8") as f:
                json.dump(history_dict, f, indent=2)

        return history_dict