Spaces:
Configuration error
Configuration error
File size: 8,223 Bytes
3a2e5f0 91a1214 3a2e5f0 91a1214 3a2e5f0 91a1214 3a2e5f0 91a1214 3a2e5f0 91a1214 3a2e5f0 91a1214 3a2e5f0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | """Typed configuration schemas (Pydantic v2 ``BaseSettings``).
These classes replace the bare globals ``MAX_LENGTH``, ``BATCH_SIZE``, ... that
the notebook holds in cell 6. The advantages of doing this:
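1. **Type safety** — every field has a declared type and Pydantic validates
it at load time. A YAML typo (``batch_size: sixty-four`` instead of a number)
raises an error pointing at the file and field, not a mysterious training
failure six steps later. (Numeric strings such as ``"64"`` are coerced to
``int`` by Pydantic's default lax mode, so they load without error.)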
2. **Env override** — ``CAPTIONING__TRAIN__BATCH_SIZE=32`` overrides
``train.batch_size`` without editing YAML. The double underscore is the
nesting delimiter (configurable below). Useful for CI smoke tests.
3. **Single source of truth** — every other module accepts a sub-config
(``ModelConfig``, ``TrainConfig``, ...) instead of pulling globals. That
makes them testable in isolation and trivially overridable in serve.
The schema mirrors the IEEE notebook 1:1 — same field names where reasonable,
same default values. Extending it (Phase 1b: warmup/cosine LR; Phase 3: model
registry) only adds new fields, never changes the meaning of existing ones.
"""
from __future__ import annotations
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class _StrictModel(BaseModel):
    """Common parent of every sub-config: unknown keys are an error.

    Out of the box Pydantic uses ``extra="ignore"``, which quietly discards
    misspelled fields. For ML hyperparameters that is the worst outcome: a
    typo (``vocabularsy_size`` instead of ``vocabulary_size``) falls back to
    the default and the model trains with the wrong value. With extras
    forbidden, every such typo becomes a load-time error that names the
    offending field.

    Note: ``AppConfig`` sets ``extra="forbid"`` on its own, because
    ``BaseSettings`` is configured through ``SettingsConfigDict`` rather
    than ``ConfigDict``.
    """

    # Inherited by every subclass, so each sub-config rejects unknown keys.
    model_config = ConfigDict(extra="forbid")
class DataConfig(_StrictModel):
    """Where the dataset lives and how much of it to use.

    Attributes:
        base_path: Root of the COCO dataset. Mirrors the notebook's
            ``BASE_PATH = '../input/coco-2017-dataset/coco2017'``.
        annotations_filename: Name of the captions JSON inside ``annotations/``.
        images_subdir: Sub-folder under ``base_path`` containing JPEGs.
        sample_size: How many caption pairs to sample. The notebook samples
            120k. Set to ``-1`` to use the full set. Any other non-positive
            value is rejected at load time.
        train_val_split: Fraction of *images* (not captions) used for training.
            Splitting at the image level prevents the same image appearing in
            both splits via different captions — a real leakage source.

    Raises:
        pydantic.ValidationError: If ``sample_size`` is neither ``-1`` nor a
            positive integer, or if ``train_val_split`` is outside ``(0, 1)``.
    """

    base_path: Path = Path("data/coco2017")
    annotations_filename: str = "captions_train2017.json"
    images_subdir: str = "train2017"
    sample_size: int = 120_000
    train_val_split: float = 0.8

    @field_validator("sample_size")
    @classmethod
    def _validate_sample_size(cls, v: int) -> int:
        """Accept a positive sample count or the ``-1`` "full set" sentinel."""
        # 0 and values below -1 have no documented meaning; fail early instead
        # of letting them reach the sampling code.
        if v != -1 and v < 1:
            raise ValueError(
                f"sample_size must be a positive integer or -1 (full set), got {v}"
            )
        return v

    @field_validator("train_val_split")
    @classmethod
    def _validate_split(cls, v: float) -> float:
        """Require a strict fraction so neither split can be empty by config."""
        if not 0.0 < v < 1.0:
            raise ValueError(f"train_val_split must be in (0, 1), got {v}")
        return v
class ModelConfig(_StrictModel):
    """Architecture hyperparameters.

    Defaults match the IEEE paper / notebook cell 6 exactly. Changing any of
    these requires re-training and re-publishing the model card on HF Hub.

    Every size and head count must be a positive integer and every dropout
    rate must lie in ``[0, 1)``. Out-of-range values are rejected at load
    time rather than surfacing later as a model-construction failure.

    Raises:
        pydantic.ValidationError: If a size or head count is not positive, or
            a dropout rate is outside ``[0, 1)``.
    """

    embedding_dim: int = Field(default=512, gt=0)
    units: int = Field(default=512, gt=0)
    max_length: int = Field(default=40, gt=0)
    vocabulary_size: int = Field(default=15_000, gt=0)
    # Notebook cell 21: TransformerEncoderLayer(EMBEDDING_DIM, 1)
    encoder_num_heads: int = Field(default=1, gt=0)
    # Notebook cell 21: TransformerDecoderLayer(..., 8)
    decoder_num_heads: int = Field(default=8, gt=0)
    # Notebook cell 19: dropout_1
    decoder_dropout_inner: float = Field(default=0.3, ge=0.0, lt=1.0)
    # Notebook cell 19: dropout_2
    decoder_dropout_outer: float = Field(default=0.5, ge=0.0, lt=1.0)
    # Notebook cell 19: MultiHeadAttention(dropout=0.1)
    decoder_attention_dropout: float = Field(default=0.1, ge=0.0, lt=1.0)
class TrainConfig(_StrictModel):
    """Optimisation hyperparameters.

    The Phase 1 baseline mirrors the IEEE notebook: constant LR, no label
    smoothing, dropout-active validation (a notebook quirk preserved for
    parity). The fields below the comment line are *opt-in* training-
    stability knobs added during the caption-quality stabilisation phase.
    Defaults keep every existing run byte-for-byte identical to the
    notebook; flipping the flag in YAML opts a run into the modern recipe.

    Numeric fields carry range constraints so that values which cannot
    describe a training run (negative epochs, zero batch size, non-positive
    learning rate, ...) fail at load time.

    Raises:
        pydantic.ValidationError: If any field is outside its allowed range
            or ``lr_schedule`` is not one of the supported names.
    """

    epochs: int = Field(default=10, ge=0)
    batch_size: int = Field(default=64, gt=0)
    buffer_size: int = Field(default=1_000, gt=0)  # tf.data shuffle buffer
    early_stopping_patience: int = Field(default=3, ge=0)
    seed: int = 42  # NEW (not in notebook): pin RNGs for reproducibility
    # Notebook uses Keras Adam default == 1e-3
    learning_rate: float = Field(default=1e-3, gt=0.0)
    weights_filename: str = "model.h5"

    # ---- opt-in stability flags (default values preserve notebook parity) ----
    label_smoothing: float = 0.0
    lr_schedule: str = "constant"  # "constant" | "cosine"
    warmup_steps: int = 0
    # If None, derived from epochs * steps_per_epoch
    cosine_decay_steps: int | None = Field(default=None, gt=0)
    min_learning_rate: float = Field(default=0.0, ge=0.0)
    honour_training_flag_in_test_step: bool = False  # parity-quirk override

    @field_validator("label_smoothing")
    @classmethod
    def _validate_label_smoothing(cls, v: float) -> float:
        """Require ``label_smoothing`` in the half-open interval ``[0, 1)``."""
        if not 0.0 <= v < 1.0:
            raise ValueError(f"label_smoothing must be in [0, 1), got {v}")
        return v

    @field_validator("lr_schedule")
    @classmethod
    def _validate_lr_schedule(cls, v: str) -> str:
        """Allow only the schedule names the trainer knows about."""
        if v not in {"constant", "cosine"}:
            raise ValueError(f"lr_schedule must be 'constant' or 'cosine', got {v!r}")
        return v

    @field_validator("warmup_steps")
    @classmethod
    def _validate_warmup_steps(cls, v: int) -> int:
        """Reject negative warmup lengths; ``0`` is the default."""
        if v < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {v}")
        return v
class ServeConfig(_StrictModel):
    """Settings for the FastAPI backend (Phase 2). Defined here so the schema
    is complete and tests don't have to mock a sub-config's existence.

    Decoding-related defaults are deliberately conservative: ``greedy`` stays
    the default for byte-for-byte parity with the IEEE notebook. Switching to
    beam at deploy time is a one-line YAML override:

        serve:
          decode_strategy: beam
          beam_width: 4
          length_penalty: 0.7
          repetition_penalty: 1.1
          no_repeat_ngram_size: 3

    Raises:
        pydantic.ValidationError: If ``decode_strategy`` is unknown,
            ``beam_width`` < 1, ``repetition_penalty`` < 1.0,
            ``max_upload_bytes`` is not positive, or
            ``no_repeat_ngram_size`` is negative.
    """

    # A non-positive upload limit would reject every request, so fail at load.
    max_upload_bytes: int = Field(default=10 * 1024 * 1024, gt=0)  # 10 MB
    decode_strategy: str = "greedy"
    beam_width: int = 3
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    # A negative n-gram size has no meaning; 0 is the default.
    # NOTE(review): 0 presumably disables the constraint — confirm in the decoder.
    no_repeat_ngram_size: int = Field(default=0, ge=0)
    cors_allowed_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"])

    @field_validator("decode_strategy")
    @classmethod
    def _validate_decode_strategy(cls, v: str) -> str:
        """Allow only the two supported decoding strategies."""
        if v not in {"greedy", "beam"}:
            raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {v!r}")
        return v

    @field_validator("beam_width")
    @classmethod
    def _validate_beam_width(cls, v: int) -> int:
        """Require at least one beam."""
        if v < 1:
            raise ValueError(f"beam_width must be >= 1, got {v}")
        return v

    @field_validator("repetition_penalty")
    @classmethod
    def _validate_repetition_penalty(cls, v: float) -> float:
        """Require a penalty of at least 1.0 (1.0 disables it)."""
        if v < 1.0:
            raise ValueError(f"repetition_penalty must be >= 1.0 (1.0 disables it), got {v}")
        return v
class AppConfig(BaseSettings):
    """Root configuration object that bundles all sub-configs.

    ``captioning.config.loader.load_config(yaml_path)`` is the loader for
    this class. Environment variables prefixed with ``CAPTIONING__`` override
    fields at any nesting depth.
    """

    model_config = SettingsConfigDict(
        env_prefix="CAPTIONING__",
        env_nested_delimiter="__",
        case_sensitive=False,
        # Reject unknown keys — catches typos at load time
        extra="forbid",
    )

    # Each section falls back to its own defaults when omitted.
    data: DataConfig = Field(default_factory=DataConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    train: TrainConfig = Field(default_factory=TrainConfig)
    serve: ServeConfig = Field(default_factory=ServeConfig)
|