File size: 4,122 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Tests for the Pydantic config schema and YAML loader."""

from __future__ import annotations

from pathlib import Path

import pytest
from pydantic import ValidationError

from captioning.config.loader import load_config
from captioning.config.schema import AppConfig, DataConfig, ModelConfig, TrainConfig


def test_defaults_match_notebook_hyperparams() -> None:
    """Pin every default to the IEEE notebook's hyperparameters.

    The defaults *are* the notebook's values, so an accidental edit to any
    of them makes this test fail loudly.
    """
    cfg = AppConfig()
    model, train, data = cfg.model, cfg.train, cfg.data

    # Model architecture
    assert model.embedding_dim == 512
    assert model.units == 512
    assert model.max_length == 40
    assert model.vocabulary_size == 15_000
    assert model.encoder_num_heads == 1
    assert model.decoder_num_heads == 8

    # Training loop
    assert train.epochs == 10
    assert train.batch_size == 64
    assert train.buffer_size == 1_000
    assert train.early_stopping_patience == 3

    # Data sampling / splitting
    assert data.sample_size == 120_000
    assert data.train_val_split == 0.8


def test_split_validation_rejects_invalid_fractions() -> None:
    """Reject the split endpoints 0 and 1, and values above 1."""
    for bad_split in (0.0, 1.0, 1.5):
        with pytest.raises(ValidationError):
            DataConfig(train_val_split=bad_split)


def test_extra_keys_rejected() -> None:
    """Unknown keys fail validation (``extra="forbid"``).

    This surfaces typos when the config is loaded rather than mid-training.
    """
    # "tpyo" is a deliberately misspelled, unknown key.
    model_section_with_typo = {"embedding_dim": 512, "tpyo": True}
    with pytest.raises(ValidationError):
        AppConfig(model=model_section_with_typo)  # type: ignore[arg-type]


def test_env_override(monkeypatch: pytest.MonkeyPatch) -> None:
    """Setting ``CAPTIONING__TRAIN__BATCH_SIZE`` overrides ``train.batch_size``."""
    monkeypatch.setenv("CAPTIONING__TRAIN__BATCH_SIZE", "32")
    # The env var holds the string "32"; the config should expose the int 32.
    assert AppConfig().train.batch_size == 32


def test_load_config_yaml(tmp_path: Path) -> None:
    """Values from a YAML file override defaults; omitted fields keep defaults."""
    yaml_text = """
data:
  sample_size: 1000
model:
  embedding_dim: 256
train:
  epochs: 2
  batch_size: 8
"""
    p = tmp_path / "test.yaml"
    p.write_text(yaml_text, encoding="utf-8")
    cfg = load_config(p)
    assert cfg.data.sample_size == 1000
    assert cfg.model.embedding_dim == 256
    assert cfg.train.epochs == 2
    # batch_size is set in the YAML above, so assert it too. Without this
    # check, a loader that dropped the key (falling back to the default)
    # would still pass.
    assert cfg.train.batch_size == 8
    # Unspecified fields take defaults
    assert cfg.model.max_length == 40


def test_load_config_missing_file(tmp_path: Path) -> None:
    """Loading a path that does not exist raises ``FileNotFoundError``."""
    missing_path = tmp_path / "does-not-exist.yaml"
    with pytest.raises(FileNotFoundError):
        load_config(missing_path)


def test_train_seed_default_is_42() -> None:
    """The notebook left training unseeded; this project defaults to seed 42."""
    default_seed = TrainConfig().seed
    assert default_seed == 42


def test_modelconfig_independent_of_other_sections() -> None:
    """``ModelConfig`` can be built directly, without an ``AppConfig`` parent."""
    model_cfg = ModelConfig(embedding_dim=128, vocabulary_size=500)

    # Explicitly passed values are kept...
    assert model_cfg.embedding_dim == 128
    assert model_cfg.vocabulary_size == 500
    # ...while fields that were not passed fall back to their defaults.
    assert model_cfg.max_length == 40


# ---- Opt-in stability flags ------------------------------------------------


def test_train_stability_defaults_preserve_notebook_parity() -> None:
    """Check the opt-in stability knobs all default to their neutral values.

    Defaults: no label smoothing, a constant LR schedule, no warmup, and the
    test-step training-flag fix disabled.
    """
    train_cfg = TrainConfig()
    assert train_cfg.label_smoothing == 0.0
    assert train_cfg.lr_schedule == "constant"
    assert train_cfg.warmup_steps == 0
    assert train_cfg.honour_training_flag_in_test_step is False


def test_label_smoothing_rejects_out_of_range() -> None:
    """``label_smoothing`` of 1.0 or a negative value fails validation."""
    for out_of_range in (1.0, -0.1):
        with pytest.raises(ValidationError):
            TrainConfig(label_smoothing=out_of_range)


def test_lr_schedule_rejects_unknown() -> None:
    """An unrecognised ``lr_schedule`` name fails validation."""
    unsupported_schedule = "square_wave"
    with pytest.raises(ValidationError):
        TrainConfig(lr_schedule=unsupported_schedule)


def test_decode_strategy_validates() -> None:
    """Reject an unknown decode strategy; accept ``"beam"`` with a beam width."""
    from captioning.config.schema import ServeConfig

    # "nucleus" is not an accepted strategy name.
    with pytest.raises(ValidationError):
        ServeConfig(decode_strategy="nucleus")

    beam_cfg = ServeConfig(decode_strategy="beam", beam_width=4)
    assert beam_cfg.beam_width == 4


def test_beam_width_and_repetition_penalty_rejected_out_of_range() -> None:
    """Both ``beam_width=0`` and ``repetition_penalty=0.5`` fail validation."""
    from captioning.config.schema import ServeConfig

    # Each builder constructs a ServeConfig with one out-of-range field.
    invalid_builders = (
        lambda: ServeConfig(beam_width=0),
        lambda: ServeConfig(repetition_penalty=0.5),
    )
    for build in invalid_builders:
        with pytest.raises(ValidationError):
            build()