Spaces:
Configuration error
Configuration error
File size: 4,122 Bytes
"""Tests for the Pydantic config schema and YAML loader."""
from __future__ import annotations
from pathlib import Path
import pytest
from pydantic import ValidationError
from captioning.config.loader import load_config
from captioning.config.schema import AppConfig, DataConfig, ModelConfig, TrainConfig
def test_defaults_match_notebook_hyperparams() -> None:
    """Pin every default to the IEEE notebook's hyperparameters.

    An accidental change to any default below makes this test fail loudly.
    """
    defaults = AppConfig()
    model, train, data = defaults.model, defaults.train, defaults.data

    # Model section.
    assert model.embedding_dim == 512
    assert model.units == 512
    assert model.max_length == 40
    assert model.vocabulary_size == 15_000
    assert model.encoder_num_heads == 1
    assert model.decoder_num_heads == 8

    # Training section.
    assert train.epochs == 10
    assert train.batch_size == 64
    assert train.buffer_size == 1_000
    assert train.early_stopping_patience == 3

    # Data section.
    assert data.sample_size == 120_000
    assert data.train_val_split == 0.8
def test_split_validation_rejects_invalid_fractions() -> None:
    """``train_val_split`` values of 0.0, 1.0 and 1.5 must fail validation."""
    for bad_fraction in (0.0, 1.0, 1.5):
        with pytest.raises(ValidationError):
            DataConfig(train_val_split=bad_fraction)
def test_extra_keys_rejected() -> None:
    """Unknown keys fail validation (``extra="forbid"``), so typos surface
    when the config is loaded rather than during training."""
    model_section_with_typo = {"embedding_dim": 512, "tpyo": True}
    with pytest.raises(ValidationError):
        AppConfig(model=model_section_with_typo)  # type: ignore[arg-type]
def test_env_override(monkeypatch: pytest.MonkeyPatch) -> None:
    """Setting ``CAPTIONING__TRAIN__BATCH_SIZE`` overrides ``train.batch_size``."""
    monkeypatch.setenv("CAPTIONING__TRAIN__BATCH_SIZE", "32")
    assert AppConfig().train.batch_size == 32
def test_load_config_yaml(tmp_path: Path) -> None:
    """Load a small YAML file and check overrides and fallbacks.

    Every value written to the file must come back on the loaded config,
    and a field the file omits must keep its schema default.
    """
    # Nested keys are indented under their section; without the indent
    # they would be parsed as top-level keys rather than section fields.
    yaml_text = """
data:
  sample_size: 1000
model:
  embedding_dim: 256
train:
  epochs: 2
  batch_size: 8
"""
    config_path = tmp_path / "test.yaml"
    config_path.write_text(yaml_text, encoding="utf-8")

    cfg = load_config(config_path)

    # Values taken from the file.
    assert cfg.data.sample_size == 1000
    assert cfg.model.embedding_dim == 256
    assert cfg.train.epochs == 2
    assert cfg.train.batch_size == 8
    # Unspecified fields take defaults
    assert cfg.model.max_length == 40
def test_load_config_missing_file(tmp_path: Path) -> None:
    """``load_config`` raises ``FileNotFoundError`` for a path that does not exist."""
    missing_path = tmp_path / "does-not-exist.yaml"
    with pytest.raises(FileNotFoundError):
        load_config(missing_path)
def test_train_seed_default_is_42() -> None:
    """The notebook ran unseeded; this project seeds, and 42 is the default."""
    train_defaults = TrainConfig()
    assert train_defaults.seed == 42
def test_modelconfig_independent_of_other_sections() -> None:
    """A sub-config can be built on its own, without the parent ``AppConfig``."""
    model_cfg = ModelConfig(embedding_dim=128, vocabulary_size=500)

    # Explicit values are kept ...
    assert model_cfg.embedding_dim == 128
    assert model_cfg.vocabulary_size == 500
    # ... and anything not passed stays at its default.
    assert model_cfg.max_length == 40
# ---- Opt-in stability flags ------------------------------------------------
def test_train_stability_defaults_preserve_notebook_parity() -> None:
    """Pin the default values of the opt-in stability settings on ``TrainConfig``."""
    train_defaults = TrainConfig()

    assert train_defaults.label_smoothing == 0.0
    assert train_defaults.lr_schedule == "constant"
    assert train_defaults.warmup_steps == 0
    assert train_defaults.honour_training_flag_in_test_step is False
def test_label_smoothing_rejects_out_of_range() -> None:
    """``label_smoothing`` values of 1.0 and -0.1 must fail validation."""
    for bad_smoothing in (1.0, -0.1):
        with pytest.raises(ValidationError):
            TrainConfig(label_smoothing=bad_smoothing)
def test_lr_schedule_rejects_unknown() -> None:
    """An unrecognised ``lr_schedule`` name must fail validation."""
    unknown_schedule = "square_wave"
    with pytest.raises(ValidationError):
        TrainConfig(lr_schedule=unknown_schedule)
def test_decode_strategy_validates() -> None:
    """``decode_strategy`` rejects an unknown name and accepts ``"beam"``."""
    from captioning.config.schema import ServeConfig

    # Unsupported strategy name.
    with pytest.raises(ValidationError):
        ServeConfig(decode_strategy="nucleus")

    # Supported strategy; the requested beam width is kept.
    beam_cfg = ServeConfig(decode_strategy="beam", beam_width=4)
    assert beam_cfg.beam_width == 4
def test_beam_width_and_repetition_penalty_rejected_out_of_range() -> None:
    """``beam_width=0`` and ``repetition_penalty=0.5`` must each fail validation."""
    from captioning.config.schema import ServeConfig

    invalid_settings = ({"beam_width": 0}, {"repetition_penalty": 0.5})
    for bad_kwargs in invalid_settings:
        with pytest.raises(ValidationError):
            ServeConfig(**bad_kwargs)
|