File size: 1,195 Bytes
91a1214
3a2e5f0
 
 
91a1214
3a2e5f0
91a1214
 
 
 
3a2e5f0
 
 
 
91a1214
 
 
 
 
 
3a2e5f0
 
 
 
91a1214
 
 
3a2e5f0
91a1214
3a2e5f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""Training — losses, schedules, callbacks, and the trainer.

The original notebook computes the loss and masked accuracy inside the model's
``train_step``. We keep that structure for parity, but we also expose the loss
functions, learning-rate schedules, and callbacks as standalone modules so
they can be unit-tested and reused.

    losses.py      ``masked_sparse_categorical_crossentropy`` (baseline) +
                   ``label_smoothed_crossentropy`` + ``build_loss``
    schedules.py   ``WarmupCosineDecay`` + ``build_learning_rate``
    callbacks.py   ``default_callbacks(config)`` — early stopping + checkpoint
    trainer.py     ``Trainer.fit()`` — wraps compile + fit + history serialization
"""

from captioning.training.callbacks import default_callbacks
from captioning.training.losses import (
    build_loss,
    label_smoothed_crossentropy,
    masked_sparse_categorical_crossentropy,
)
from captioning.training.schedules import WarmupCosineDecay, build_learning_rate
from captioning.training.trainer import Trainer

# Public names re-exported by this package from the submodules imported above.
# Keep this a plain list literal so linters and type checkers can resolve the
# export list statically; entries are in ASCII order (classes sort first).
__all__ = [
    "Trainer",  # from trainer.py
    "WarmupCosineDecay",  # from schedules.py
    "build_learning_rate",  # from schedules.py
    "build_loss",  # from losses.py
    "default_callbacks",  # from callbacks.py
    "label_smoothed_crossentropy",  # from losses.py
    "masked_sparse_categorical_crossentropy",  # from losses.py
]