apoorvrajdev's picture
feat(evaluation): add beam search, metrics pipeline, and stabilized training workflow
91a1214
"""Typed configuration schemas (Pydantic v2 ``BaseSettings``).
These classes replace the bare globals ``MAX_LENGTH``, ``BATCH_SIZE``, ... that
the notebook holds in cell 6. The advantages of doing this:
1. **Type safety** — every field has a declared type and Pydantic validates
it at load time. A YAML typo (``batch_size: sixty-four``) raises an
error pointing at the file and field, not a mysterious training failure
six steps later. (Numeric strings such as ``"64"`` are *coerced* to ``int``
by Pydantic's default lax mode rather than rejected.)
2. **Env override** — ``CAPTIONING__TRAIN__BATCH_SIZE=32`` overrides
``train.batch_size`` without editing YAML. The double underscore is the
nesting delimiter (configurable below). Useful for CI smoke tests.
3. **Single source of truth** — every other module accepts a sub-config
(``ModelConfig``, ``TrainConfig``, ...) instead of pulling globals. That
makes them testable in isolation and trivially overridable in serve.
The schema mirrors the IEEE notebook 1:1 — same field names where reasonable,
same default values. Extending it (Phase 1b: warmup/cosine LR; Phase 3: model
registry) only adds new fields, never changes the meaning of existing ones.
"""
from __future__ import annotations
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class _StrictModel(BaseModel):
    """Common parent of all sub-configs; an unknown key is a validation error.

    Left at Pydantic's default (``extra="ignore"``), a misspelled key is
    dropped without any warning. For hyperparameter configs that is the most
    harmful outcome possible: write ``vocabularsy_size`` where
    ``vocabulary_size`` was meant and the default is used silently, so the
    model trains with a value nobody asked for. With extras forbidden, the
    same typo fails at load time and the error names the offending field.

    Note: ``AppConfig`` sets ``extra="forbid"`` on its own, because
    ``BaseSettings`` is configured through ``SettingsConfigDict`` rather
    than ``ConfigDict``.
    """

    # Inherited by every subclass, so each sub-config rejects unknown keys.
    model_config = ConfigDict(
        extra="forbid",
    )
class DataConfig(_StrictModel):
    """Where the dataset lives and how much of it to use.

    Attributes:
        base_path: Root of the COCO dataset. Mirrors the notebook's
            ``BASE_PATH = '../input/coco-2017-dataset/coco2017'``.
        annotations_filename: Name of the captions JSON inside ``annotations/``.
        images_subdir: Sub-folder under ``base_path`` containing JPEGs.
        sample_size: How many caption pairs to sample. The notebook samples
            120k. Set to ``-1`` to use the full set. Any other non-positive
            value is rejected at load time.
        train_val_split: Fraction of *images* (not captions) used for training.
            Splitting at the image level prevents the same image appearing in
            both splits via different captions — a real leakage source.
            Must lie strictly between 0 and 1.
    """

    base_path: Path = Path("data/coco2017")
    annotations_filename: str = "captions_train2017.json"
    images_subdir: str = "train2017"
    sample_size: int = 120_000
    train_val_split: float = 0.8

    @field_validator("sample_size")
    @classmethod
    def _validate_sample_size(cls, v: int) -> int:
        """Accept a positive sample count or the documented ``-1`` sentinel."""
        # ``-1`` is the only documented "use everything" marker; 0 and other
        # negatives are almost certainly typos, so fail at load time.
        if v != -1 and v < 1:
            raise ValueError(
                f"sample_size must be a positive integer or -1 (full set), got {v}"
            )
        return v

    @field_validator("train_val_split")
    @classmethod
    def _validate_split(cls, v: float) -> float:
        """Require an open-interval fraction so neither split is empty."""
        if not 0.0 < v < 1.0:
            raise ValueError(f"train_val_split must be in (0, 1), got {v}")
        return v
class ModelConfig(_StrictModel):
    """Architecture hyperparameters.

    Defaults match the IEEE paper / notebook cell 6 exactly. Changing any of
    these requires re-training and re-publishing the model card on HF Hub.

    Validation: every size/count field must be a positive integer and every
    dropout rate must be a probability in ``[0, 1]``. Out-of-range values are
    rejected when the config is loaded rather than surfacing later as a
    model-construction error.
    """

    embedding_dim: int = 512
    units: int = 512
    max_length: int = 40
    vocabulary_size: int = 15_000
    encoder_num_heads: int = 1  # Notebook cell 21: TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder_num_heads: int = 8  # Notebook cell 21: TransformerDecoderLayer(..., 8)
    decoder_dropout_inner: float = 0.3  # Notebook cell 19: dropout_1
    decoder_dropout_outer: float = 0.5  # Notebook cell 19: dropout_2
    decoder_attention_dropout: float = 0.1  # Notebook cell 19: MultiHeadAttention(dropout=0.1)

    @field_validator(
        "embedding_dim",
        "units",
        "max_length",
        "vocabulary_size",
        "encoder_num_heads",
        "decoder_num_heads",
    )
    @classmethod
    def _validate_positive_int(cls, v: int) -> int:
        """Reject zero or negative sizes and head counts."""
        # The field name is not repeated in the message: Pydantic already
        # reports the location of the failing field alongside it.
        if v < 1:
            raise ValueError(f"must be >= 1, got {v}")
        return v

    @field_validator(
        "decoder_dropout_inner",
        "decoder_dropout_outer",
        "decoder_attention_dropout",
    )
    @classmethod
    def _validate_dropout_rate(cls, v: float) -> float:
        """Reject dropout rates that are not valid probabilities."""
        if not 0.0 <= v <= 1.0:
            raise ValueError(f"dropout rate must be in [0, 1], got {v}")
        return v
class TrainConfig(_StrictModel):
    """Optimisation hyperparameters.

    The Phase 1 baseline mirrors the IEEE notebook: constant LR, no label
    smoothing, dropout-active validation (a notebook quirk preserved for
    parity). The fields below the comment line are *opt-in* training-
    stability knobs added during the caption-quality stabilisation phase.
    Defaults keep every existing run byte-for-byte identical to the
    notebook; flipping the flag in YAML opts a run into the modern recipe.

    Validation (all checked at load time):
        * ``epochs`` and ``batch_size`` must be >= 1.
        * ``early_stopping_patience`` and ``warmup_steps`` must be >= 0.
        * ``learning_rate`` must be > 0; ``min_learning_rate`` must be >= 0.
        * ``label_smoothing`` must be in ``[0, 1)``.
        * ``lr_schedule`` must be ``"constant"`` or ``"cosine"``.
        * ``cosine_decay_steps`` must be ``None`` or >= 1.
    """

    epochs: int = 10
    batch_size: int = 64
    buffer_size: int = 1_000  # tf.data shuffle buffer
    early_stopping_patience: int = 3
    seed: int = 42  # NEW (not in notebook): pin RNGs for reproducibility
    learning_rate: float = 1e-3  # Notebook uses Keras Adam default == 1e-3
    weights_filename: str = "model.h5"

    # ---- opt-in stability flags (default values preserve notebook parity) ----
    label_smoothing: float = 0.0
    lr_schedule: str = "constant"  # "constant" | "cosine"
    warmup_steps: int = 0
    cosine_decay_steps: int | None = None  # If None, derived from epochs * steps_per_epoch
    min_learning_rate: float = 0.0
    honour_training_flag_in_test_step: bool = False  # parity-quirk override

    @field_validator("epochs", "batch_size")
    @classmethod
    def _validate_positive_int(cls, v: int) -> int:
        """Reject zero or negative epoch counts and batch sizes."""
        # Pydantic reports the failing field's location next to this message.
        if v < 1:
            raise ValueError(f"must be >= 1, got {v}")
        return v

    @field_validator("early_stopping_patience")
    @classmethod
    def _validate_patience(cls, v: int) -> int:
        """Reject a negative early-stopping patience."""
        if v < 0:
            raise ValueError(f"early_stopping_patience must be >= 0, got {v}")
        return v

    @field_validator("learning_rate")
    @classmethod
    def _validate_learning_rate(cls, v: float) -> float:
        """Reject a zero or negative learning rate."""
        if v <= 0.0:
            raise ValueError(f"learning_rate must be > 0, got {v}")
        return v

    @field_validator("min_learning_rate")
    @classmethod
    def _validate_min_learning_rate(cls, v: float) -> float:
        """Reject a negative learning-rate floor."""
        if v < 0.0:
            raise ValueError(f"min_learning_rate must be >= 0, got {v}")
        return v

    @field_validator("cosine_decay_steps")
    @classmethod
    def _validate_cosine_decay_steps(cls, v: int | None) -> int | None:
        """Allow ``None`` (derive later) or a positive explicit step count."""
        if v is not None and v < 1:
            raise ValueError(f"cosine_decay_steps must be >= 1 or None, got {v}")
        return v

    @field_validator("label_smoothing")
    @classmethod
    def _validate_label_smoothing(cls, v: float) -> float:
        """Require a smoothing factor in the half-open interval ``[0, 1)``."""
        if not 0.0 <= v < 1.0:
            raise ValueError(f"label_smoothing must be in [0, 1), got {v}")
        return v

    @field_validator("lr_schedule")
    @classmethod
    def _validate_lr_schedule(cls, v: str) -> str:
        """Restrict the schedule name to the two supported values."""
        if v not in {"constant", "cosine"}:
            raise ValueError(f"lr_schedule must be 'constant' or 'cosine', got {v!r}")
        return v

    @field_validator("warmup_steps")
    @classmethod
    def _validate_warmup_steps(cls, v: int) -> int:
        """Reject a negative warmup length."""
        if v < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {v}")
        return v
class ServeConfig(_StrictModel):
    """Settings for the FastAPI backend (Phase 2).

    Defined here so the schema is complete and tests don't have to mock a
    sub-config's existence.

    Decoding-related defaults are deliberately conservative: ``greedy`` stays
    the default for byte-for-byte parity with the IEEE notebook. Switching to
    beam at deploy time is a one-line YAML override::

        serve:
          decode_strategy: beam
          beam_width: 4
          length_penalty: 0.7
          repetition_penalty: 1.1
          no_repeat_ngram_size: 3

    Validation (all checked at load time):
        * ``max_upload_bytes`` must be >= 1.
        * ``decode_strategy`` must be ``"greedy"`` or ``"beam"``.
        * ``beam_width`` must be >= 1.
        * ``repetition_penalty`` must be >= 1.0 (1.0 disables it).
        * ``no_repeat_ngram_size`` must be >= 0.

    ``length_penalty`` is intentionally left unconstrained here.
    NOTE(review): its valid range depends on the decoder implementation,
    which is not visible from this module — confirm before tightening.
    """

    max_upload_bytes: int = 10 * 1024 * 1024  # 10 MB
    decode_strategy: str = "greedy"
    beam_width: int = 3
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 0
    cors_allowed_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"])

    @field_validator("max_upload_bytes")
    @classmethod
    def _validate_max_upload_bytes(cls, v: int) -> int:
        """Reject a zero or negative upload limit, which would refuse every upload."""
        if v < 1:
            raise ValueError(f"max_upload_bytes must be >= 1, got {v}")
        return v

    @field_validator("decode_strategy")
    @classmethod
    def _validate_decode_strategy(cls, v: str) -> str:
        """Restrict the strategy name to the two supported values."""
        if v not in {"greedy", "beam"}:
            raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {v!r}")
        return v

    @field_validator("beam_width")
    @classmethod
    def _validate_beam_width(cls, v: int) -> int:
        """Require at least one beam."""
        if v < 1:
            raise ValueError(f"beam_width must be >= 1, got {v}")
        return v

    @field_validator("repetition_penalty")
    @classmethod
    def _validate_repetition_penalty(cls, v: float) -> float:
        """Require a penalty of at least 1.0 (the neutral value)."""
        if v < 1.0:
            raise ValueError(f"repetition_penalty must be >= 1.0 (1.0 disables it), got {v}")
        return v

    @field_validator("no_repeat_ngram_size")
    @classmethod
    def _validate_no_repeat_ngram_size(cls, v: int) -> int:
        """Reject a negative n-gram size; the default of 0 is allowed."""
        if v < 0:
            raise ValueError(f"no_repeat_ngram_size must be >= 0, got {v}")
        return v
class AppConfig(BaseSettings):
    """Root settings object bundling the four sub-configs.

    The intended entry point is
    ``captioning.config.loader.load_config(yaml_path)``. Environment
    variables carrying the ``CAPTIONING__`` prefix override fields at any
    nesting depth, with ``__`` separating the levels.
    """

    model_config = SettingsConfigDict(
        extra="forbid",  # Unknown keys are an error, so typos surface at load time
        case_sensitive=False,
        env_nested_delimiter="__",
        env_prefix="CAPTIONING__",
    )

    # Each section falls back to its own defaults when it is omitted.
    data: DataConfig = Field(default_factory=DataConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    train: TrainConfig = Field(default_factory=TrainConfig)
    serve: ServeConfig = Field(default_factory=ServeConfig)