File size: 3,543 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Train the IEEE InceptionV3+Transformer captioning model.

Usage:
    python -m scripts.train --config configs/base.yaml
    python -m scripts.train --config configs/base.yaml --output-dir models/v1.0.0

The script orchestrates the same pipeline as the notebook, but each step is
imported from the modular package — making it the canonical example of how
the package is meant to be composed.
"""

from __future__ import annotations

from pathlib import Path

import click

from captioning.config import load_config
from captioning.data import (
    build_train_pipeline,
    build_val_pipeline,
    load_coco_annotations,
    make_image_level_splits,
)
from captioning.models import build_caption_model
from captioning.preprocessing import CaptionTokenizer, preprocess_caption
from captioning.training import Trainer
from captioning.utils import configure_logging, get_logger, set_global_seed

log = get_logger(__name__)


@click.command()
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="YAML config file (e.g. configs/base.yaml).",
)
@click.option(
    "--output-dir",
    type=click.Path(path_type=Path),
    default="outputs/runs/latest",
    help="Where to save weights, vocab, and history.",
)
def main(config_path: Path, output_dir: Path) -> None:
    """Run the full training pipeline end-to-end."""
    configure_logging()
    config = load_config(config_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    set_global_seed(config.train.seed)
    log.info("config_loaded", path=str(config_path), output_dir=str(output_dir))

    # 1. Load + preprocess COCO captions ------------------------------------
    df = load_coco_annotations(
        base_path=config.data.base_path,
        annotations_filename=config.data.annotations_filename,
        images_subdir=config.data.images_subdir,
        sample_size=config.data.sample_size,
        seed=config.train.seed,
        caption_preprocessor=preprocess_caption,
    )

    # 2. Fit and persist the tokenizer --------------------------------------
    tokenizer = CaptionTokenizer(
        vocab_size=config.model.vocabulary_size,
        max_length=config.model.max_length,
    )
    tokenizer.fit(df["caption"])
    tokenizer.save(output_dir)

    # 3. Image-level train/val split ----------------------------------------
    train_imgs, train_caps, val_imgs, val_caps = make_image_level_splits(
        df, train_fraction=config.data.train_val_split, seed=config.train.seed
    )

    # 4. tf.data pipelines ---------------------------------------------------
    train_ds = build_train_pipeline(
        train_imgs,
        train_caps,
        tokenizer,
        batch_size=config.train.batch_size,
        buffer_size=config.train.buffer_size,
    )
    val_ds = build_val_pipeline(
        val_imgs,
        val_caps,
        tokenizer,
        batch_size=config.train.batch_size,
        buffer_size=config.train.buffer_size,
    )

    # 5. Build, compile, fit -------------------------------------------------
    model = build_caption_model(config, vocab_size=tokenizer.vocabulary_size)
    trainer = Trainer(model, config)
    trainer.fit(train_ds, val_ds, output_dir=output_dir)

    # 6. Save final weights to the canonical filename ------------------------
    final_weights = output_dir / config.train.weights_filename
    model.save_weights(str(final_weights))
    log.info("training_done", weights=str(final_weights))


if __name__ == "__main__":
    main()