"""Tests for the Pydantic config schema and YAML loader.""" from __future__ import annotations from pathlib import Path import pytest from pydantic import ValidationError from captioning.config.loader import load_config from captioning.config.schema import AppConfig, DataConfig, ModelConfig, TrainConfig def test_defaults_match_notebook_hyperparams() -> None: """The defaults *are* the IEEE notebook's hyperparameters; if anyone changes them by accident, this test fails loudly.""" cfg = AppConfig() assert cfg.model.embedding_dim == 512 assert cfg.model.units == 512 assert cfg.model.max_length == 40 assert cfg.model.vocabulary_size == 15_000 assert cfg.model.encoder_num_heads == 1 assert cfg.model.decoder_num_heads == 8 assert cfg.train.epochs == 10 assert cfg.train.batch_size == 64 assert cfg.train.buffer_size == 1_000 assert cfg.train.early_stopping_patience == 3 assert cfg.data.sample_size == 120_000 assert cfg.data.train_val_split == 0.8 def test_split_validation_rejects_invalid_fractions() -> None: with pytest.raises(ValidationError): DataConfig(train_val_split=0.0) with pytest.raises(ValidationError): DataConfig(train_val_split=1.0) with pytest.raises(ValidationError): DataConfig(train_val_split=1.5) def test_extra_keys_rejected() -> None: """``extra="forbid"`` catches typos at load time instead of training time.""" with pytest.raises(ValidationError): AppConfig(model={"embedding_dim": 512, "tpyo": True}) # type: ignore[arg-type] def test_env_override(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("CAPTIONING__TRAIN__BATCH_SIZE", "32") cfg = AppConfig() assert cfg.train.batch_size == 32 def test_load_config_yaml(tmp_path: Path) -> None: yaml_text = """ data: sample_size: 1000 model: embedding_dim: 256 train: epochs: 2 batch_size: 8 """ p = tmp_path / "test.yaml" p.write_text(yaml_text, encoding="utf-8") cfg = load_config(p) assert cfg.data.sample_size == 1000 assert cfg.model.embedding_dim == 256 assert cfg.train.epochs == 2 # Unspecified fields take defaults assert cfg.model.max_length == 40 def test_load_config_missing_file(tmp_path: Path) -> None: with pytest.raises(FileNotFoundError): load_config(tmp_path / "does-not-exist.yaml") def test_train_seed_default_is_42() -> None: """The notebook didn't seed; we did. 42 is the project default.""" assert TrainConfig().seed == 42 def test_modelconfig_independent_of_other_sections() -> None: """Sub-configs should be constructible without the parent.""" m = ModelConfig(embedding_dim=128, vocabulary_size=500) assert m.embedding_dim == 128 assert m.vocabulary_size == 500 # Defaults preserved assert m.max_length == 40 # ---- Opt-in stability flags ------------------------------------------------ def test_train_stability_defaults_preserve_notebook_parity() -> None: t = TrainConfig() assert t.label_smoothing == 0.0 assert t.lr_schedule == "constant" assert t.warmup_steps == 0 assert t.honour_training_flag_in_test_step is False def test_label_smoothing_rejects_out_of_range() -> None: with pytest.raises(ValidationError): TrainConfig(label_smoothing=1.0) with pytest.raises(ValidationError): TrainConfig(label_smoothing=-0.1) def test_lr_schedule_rejects_unknown() -> None: with pytest.raises(ValidationError): TrainConfig(lr_schedule="square_wave") def test_decode_strategy_validates() -> None: from captioning.config.schema import ServeConfig with pytest.raises(ValidationError): ServeConfig(decode_strategy="nucleus") s = ServeConfig(decode_strategy="beam", beam_width=4) assert s.beam_width == 4 def test_beam_width_and_repetition_penalty_rejected_out_of_range() -> None: from captioning.config.schema import ServeConfig with pytest.raises(ValidationError): ServeConfig(beam_width=0) with pytest.raises(ValidationError): ServeConfig(repetition_penalty=0.5)