File size: 8,223 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
91a1214
3a2e5f0
91a1214
 
 
3a2e5f0
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Typed configuration schemas (Pydantic v2 ``BaseSettings``).

These classes replace the bare globals ``MAX_LENGTH``, ``BATCH_SIZE``, ... that
the notebook holds in cell 6. The advantages of doing this:

1. **Type safety** — every field has a declared type and Pydantic validates
   it at load time. A bad value (``batch_size: "sixty-four"``) raises an
   error pointing at the file and field, not a mysterious training failure
   six steps later. (In Pydantic's default lax mode a purely numeric string
   such as ``"64"`` is still coerced to ``int`` rather than rejected.)
2. **Env override** — ``CAPTIONING__TRAIN__BATCH_SIZE=32`` overrides
   ``train.batch_size`` without editing YAML. The double underscore is the
   nesting delimiter (configurable below). Useful for CI smoke tests.
3. **Single source of truth** — every other module accepts a sub-config
   (``ModelConfig``, ``TrainConfig``, ...) instead of pulling globals. That
   makes them testable in isolation and trivially overridable in serve.

The schema mirrors the IEEE notebook 1:1 — same field names where reasonable,
same default values. Extending it (Phase 1b: warmup/cosine LR; Phase 3: model
registry) only adds new fields, never changes the meaning of existing ones.
"""

from __future__ import annotations

from pathlib import Path

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class _StrictModel(BaseModel):
    """Shared base for every sub-config — rejects unknown keys.

    Pydantic's default ``extra="ignore"`` silently drops misspelled fields.
    For configs that drive ML hyperparameters that's the worst possible
    behaviour: a typo (``vocabularsy_size`` instead of ``vocabulary_size``)
    silently uses the default and the model trains with the wrong value.
    Forbidding extras turns every typo into a load-time error pointing at
    the offending field.

    Note: ``extra="forbid"`` is set on ``AppConfig`` separately because
    ``BaseSettings`` uses ``SettingsConfigDict``, not ``ConfigDict``.
    """

    # Only the handling of unknown keys is changed here; type coercion stays
    # at Pydantic's default (strict mode is not enabled).
    model_config = ConfigDict(extra="forbid")


class DataConfig(_StrictModel):
    """Where the dataset lives and how much of it to use.

    Attributes:
        base_path: Root of the COCO dataset. Mirrors the notebook's
            ``BASE_PATH = '../input/coco-2017-dataset/coco2017'``.
        annotations_filename: Name of the captions JSON inside ``annotations/``.
        images_subdir: Sub-folder under ``base_path`` containing JPEGs.
        sample_size: How many caption pairs to sample. The notebook samples
            120k. Set to ``-1`` to use the full set; any other value must be
            a positive count.
        train_val_split: Fraction of *images* (not captions) used for training.
            Splitting at the image level prevents the same image appearing in
            both splits via different captions — a real leakage source.

    Raises:
        pydantic.ValidationError: At construction time if ``sample_size`` is
            neither ``-1`` nor positive, or if ``train_val_split`` lies
            outside the open interval (0, 1).
    """

    base_path: Path = Path("data/coco2017")
    annotations_filename: str = "captions_train2017.json"
    images_subdir: str = "train2017"
    sample_size: int = 120_000
    train_val_split: float = 0.8

    @field_validator("sample_size")
    @classmethod
    def _validate_sample_size(cls, v: int) -> int:
        """Accept the ``-1`` sentinel or a positive sample count."""
        # -1 is the only documented "use everything" sentinel. 0 and other
        # negatives have no defined meaning, so fail at load time instead of
        # letting them reach the sampling code.
        if v != -1 and v <= 0:
            raise ValueError(f"sample_size must be -1 (full set) or > 0, got {v}")
        return v

    @field_validator("train_val_split")
    @classmethod
    def _validate_split(cls, v: float) -> float:
        """Require a split fraction strictly between 0 and 1."""
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 < v < 1.0:
            raise ValueError(f"train_val_split must be in (0, 1), got {v}")
        return v


class ModelConfig(_StrictModel):
    """Architecture hyperparameters.

    Defaults match the IEEE paper / notebook cell 6 exactly. Changing any of
    these requires re-training and re-publishing the model card on HF Hub.

    Every size/count field must be at least 1 and every dropout field must be
    a probability in [0, 1]; anything else is rejected when the config is
    loaded rather than when the model is built.

    Raises:
        pydantic.ValidationError: At construction time if a size/count field
            is below 1 or a dropout rate is outside [0, 1]. The error's
            location identifies the offending field.
    """

    embedding_dim: int = 512
    units: int = 512
    max_length: int = 40
    vocabulary_size: int = 15_000
    encoder_num_heads: int = 1  # Notebook cell 21: TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder_num_heads: int = 8  # Notebook cell 21: TransformerDecoderLayer(..., 8)
    decoder_dropout_inner: float = 0.3  # Notebook cell 19: dropout_1
    decoder_dropout_outer: float = 0.5  # Notebook cell 19: dropout_2
    decoder_attention_dropout: float = 0.1  # Notebook cell 19: MultiHeadAttention(dropout=0.1)

    @field_validator(
        "embedding_dim",
        "units",
        "max_length",
        "vocabulary_size",
        "encoder_num_heads",
        "decoder_num_heads",
    )
    @classmethod
    def _validate_positive(cls, v: int) -> int:
        """Require sizes and counts to be at least 1."""
        # The field name is omitted from the message because Pydantic already
        # reports the failing field's location alongside it.
        if v < 1:
            raise ValueError(f"must be >= 1, got {v}")
        return v

    @field_validator(
        "decoder_dropout_inner",
        "decoder_dropout_outer",
        "decoder_attention_dropout",
    )
    @classmethod
    def _validate_dropout(cls, v: float) -> float:
        """Require dropout rates to be probabilities in [0, 1]."""
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= v <= 1.0:
            raise ValueError(f"dropout rate must be in [0, 1], got {v}")
        return v


class TrainConfig(_StrictModel):
    """Optimisation hyperparameters.

    The Phase 1 baseline mirrors the IEEE notebook: constant LR, no label
    smoothing, dropout-active validation (a notebook quirk preserved for
    parity). The fields below the comment line are *opt-in* training-
    stability knobs added during the caption-quality stabilisation phase.
    Defaults keep every existing run byte-for-byte identical to the
    notebook; flipping the flag in YAML opts a run into the modern recipe.

    Raises:
        pydantic.ValidationError: At construction time if ``epochs`` or
            ``batch_size`` is below 1, ``learning_rate`` is not positive,
            ``min_learning_rate`` is negative, ``cosine_decay_steps`` is set
            but below 1, ``label_smoothing`` is outside [0, 1),
            ``lr_schedule`` is not ``"constant"``/``"cosine"``, or
            ``warmup_steps`` is negative.
    """

    epochs: int = 10
    batch_size: int = 64
    buffer_size: int = 1_000  # tf.data shuffle buffer
    early_stopping_patience: int = 3
    seed: int = 42  # NEW (not in notebook): pin RNGs for reproducibility
    learning_rate: float = 1e-3  # Notebook uses Keras Adam default == 1e-3
    weights_filename: str = "model.h5"

    # ---- opt-in stability flags (default values preserve notebook parity) ----
    label_smoothing: float = 0.0
    lr_schedule: str = "constant"  # "constant" | "cosine"
    warmup_steps: int = 0
    cosine_decay_steps: int | None = None  # If None, derived from epochs * steps_per_epoch
    min_learning_rate: float = 0.0
    honour_training_flag_in_test_step: bool = False  # parity-quirk override

    @field_validator("epochs", "batch_size")
    @classmethod
    def _validate_positive_count(cls, v: int) -> int:
        """Require at least one epoch and at least one sample per batch."""
        # The field name is omitted from the message because Pydantic already
        # reports the failing field's location alongside it.
        if v < 1:
            raise ValueError(f"must be >= 1, got {v}")
        return v

    @field_validator("learning_rate")
    @classmethod
    def _validate_learning_rate(cls, v: float) -> float:
        """Require a strictly positive learning rate."""
        # ``not v > 0`` (rather than ``v <= 0``) also rejects NaN.
        if not v > 0.0:
            raise ValueError(f"learning_rate must be > 0, got {v}")
        return v

    @field_validator("min_learning_rate")
    @classmethod
    def _validate_min_learning_rate(cls, v: float) -> float:
        """Require a non-negative floor for the decayed learning rate."""
        # ``not v >= 0`` (rather than ``v < 0``) also rejects NaN.
        if not v >= 0.0:
            raise ValueError(f"min_learning_rate must be >= 0, got {v}")
        return v

    @field_validator("cosine_decay_steps")
    @classmethod
    def _validate_cosine_decay_steps(cls, v: int | None) -> int | None:
        """Allow ``None`` (derive the value later) or a positive step count."""
        if v is not None and v < 1:
            raise ValueError(f"cosine_decay_steps must be >= 1 or null, got {v}")
        return v

    @field_validator("label_smoothing")
    @classmethod
    def _validate_label_smoothing(cls, v: float) -> float:
        """Require label smoothing in the half-open interval [0, 1)."""
        if not 0.0 <= v < 1.0:
            raise ValueError(f"label_smoothing must be in [0, 1), got {v}")
        return v

    @field_validator("lr_schedule")
    @classmethod
    def _validate_lr_schedule(cls, v: str) -> str:
        """Restrict the schedule name to the two supported values."""
        if v not in {"constant", "cosine"}:
            raise ValueError(f"lr_schedule must be 'constant' or 'cosine', got {v!r}")
        return v

    @field_validator("warmup_steps")
    @classmethod
    def _validate_warmup_steps(cls, v: int) -> int:
        """Require a non-negative warmup length (0 means no warmup)."""
        if v < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {v}")
        return v


class ServeConfig(_StrictModel):
    """Settings for the FastAPI backend (Phase 2). Defined here so the schema
    is complete and tests don't have to mock a sub-config's existence.

    Decoding-related defaults are deliberately conservative: ``greedy`` stays
    the default for byte-for-byte parity with the IEEE notebook. Switching to
    beam at deploy time is a one-line YAML override:

        serve:
          decode_strategy: beam
          beam_width: 4
          length_penalty: 0.7
          repetition_penalty: 1.1
          no_repeat_ngram_size: 3

    Raises:
        pydantic.ValidationError: At construction time if
            ``max_upload_bytes`` is not positive, ``decode_strategy`` is not
            ``"greedy"``/``"beam"``, ``beam_width`` is below 1,
            ``repetition_penalty`` is below 1.0 (or NaN), or
            ``no_repeat_ngram_size`` is negative.
    """

    max_upload_bytes: int = 10 * 1024 * 1024  # 10 MB
    decode_strategy: str = "greedy"
    beam_width: int = 3
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 0
    cors_allowed_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"])

    @field_validator("max_upload_bytes")
    @classmethod
    def _validate_max_upload_bytes(cls, v: int) -> int:
        """Require a positive upload size limit."""
        if v < 1:
            raise ValueError(f"max_upload_bytes must be >= 1, got {v}")
        return v

    @field_validator("decode_strategy")
    @classmethod
    def _validate_decode_strategy(cls, v: str) -> str:
        """Restrict the decoding strategy to the two supported values."""
        if v not in {"greedy", "beam"}:
            raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {v!r}")
        return v

    @field_validator("beam_width")
    @classmethod
    def _validate_beam_width(cls, v: int) -> int:
        """Require at least one beam."""
        if v < 1:
            raise ValueError(f"beam_width must be >= 1, got {v}")
        return v

    @field_validator("repetition_penalty")
    @classmethod
    def _validate_repetition_penalty(cls, v: float) -> float:
        """Require a repetition penalty of at least 1.0 (1.0 disables it)."""
        # ``not v >= 1.0`` instead of ``v < 1.0``: every comparison with NaN
        # is False, so the latter would let NaN slip through unnoticed.
        if not v >= 1.0:
            raise ValueError(f"repetition_penalty must be >= 1.0 (1.0 disables it), got {v}")
        return v

    @field_validator("no_repeat_ngram_size")
    @classmethod
    def _validate_no_repeat_ngram_size(cls, v: int) -> int:
        """Require a non-negative n-gram size."""
        # 0 is the default; NOTE(review): presumably 0 means "disabled" in the
        # decoder — confirm there. A negative size has no meaning either way.
        if v < 0:
            raise ValueError(f"no_repeat_ngram_size must be >= 0, got {v}")
        return v


class AppConfig(BaseSettings):
    """Root settings object bundling the four sub-configs.

    ``captioning.config.loader.load_config(yaml_path)`` is the loader for
    this class. Environment variables can override any field at any nesting
    depth: names start with the ``CAPTIONING__`` prefix and use ``__``
    between nesting levels.
    """

    model_config = SettingsConfigDict(
        extra="forbid",  # Reject unknown keys — catches typos at load time
        case_sensitive=False,
        env_nested_delimiter="__",
        env_prefix="CAPTIONING__",
    )

    # Each section falls back to its own all-defaults instance when omitted.
    data: DataConfig = Field(default_factory=DataConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    train: TrainConfig = Field(default_factory=TrainConfig)
    serve: ServeConfig = Field(default_factory=ServeConfig)