"""Unit tests for the stratified LHS sampler (initial baseline-surrogate schema).""" from __future__ import annotations from collections import Counter import pytest from roverdevkit.schema import DesignVector, MissionScenario from roverdevkit.surrogate.sampling import ( _CONTINUOUS_DESIGN_BOUNDS, _GROUSER_COUNT_BOUNDS, _SOIL_BOUNDS, FAMILIES, LHSSample, generate_samples, ) # --------------------------------------------------------------------------- # Basic shape / typing # --------------------------------------------------------------------------- def test_generate_samples_total_count_matches_contract() -> None: samples = generate_samples(n_per_scenario=8, seed=0) assert len(samples) == 8 * len(FAMILIES) def test_generate_samples_subset_of_families() -> None: samples = generate_samples( n_per_scenario=4, seed=0, scenario_names=["equatorial_mare_traverse", "polar_prospecting"], ) assert len(samples) == 8 families = {s.scenario_family for s in samples} assert families == {"equatorial_mare_traverse", "polar_prospecting"} def test_generate_samples_rejects_unknown_family() -> None: with pytest.raises(KeyError, match="unknown scenario family"): generate_samples(n_per_scenario=2, scenario_names=["no_such_scenario"]) def test_generate_samples_rejects_odd_n() -> None: with pytest.raises(ValueError, match="even"): generate_samples(n_per_scenario=7) def test_generate_samples_rejects_nonpositive_n() -> None: with pytest.raises(ValueError, match="positive"): generate_samples(n_per_scenario=0) def test_sample_objects_are_typed() -> None: samples = generate_samples(n_per_scenario=2, seed=0) assert all(isinstance(s, LHSSample) for s in samples) assert all(isinstance(s.design, DesignVector) for s in samples) assert all(isinstance(s.scenario, MissionScenario) for s in samples) # --------------------------------------------------------------------------- # Determinism # --------------------------------------------------------------------------- def test_generate_samples_is_deterministic_for_same_seed() -> None: s1 = generate_samples(n_per_scenario=8, seed=123) s2 = generate_samples(n_per_scenario=8, seed=123) for a, b in zip(s1, s2, strict=True): assert a.design == b.design assert a.scenario == b.scenario assert a.soil == b.soil assert a.split == b.split assert a.sample_index == b.sample_index def test_different_seeds_produce_different_draws() -> None: s1 = generate_samples(n_per_scenario=8, seed=123) s2 = generate_samples(n_per_scenario=8, seed=456) diffs = sum(1 for a, b in zip(s1, s2, strict=True) if a.design != b.design) assert diffs > 0 # --------------------------------------------------------------------------- # Stratification # --------------------------------------------------------------------------- def test_wheel_strata_are_exact_50_50() -> None: samples = generate_samples(n_per_scenario=20, seed=7) for family_name in FAMILIES: fam = [s for s in samples if s.scenario_family == family_name] counts = Counter(s.design.n_wheels for s in fam) assert counts[4] == 10, f"{family_name}: 4-wheel count {counts[4]} != 10" assert counts[6] == 10, f"{family_name}: 6-wheel count {counts[6]} != 10" def test_stratum_id_matches_n_wheels() -> None: samples = generate_samples(n_per_scenario=8, seed=0) for s in samples: expected = 0 if s.design.n_wheels == 4 else 1 assert s.stratum_id == expected # --------------------------------------------------------------------------- # Splits # --------------------------------------------------------------------------- def test_split_labels_are_valid() -> None: samples = generate_samples(n_per_scenario=10, seed=0) labels = {s.split for s in samples} assert labels <= {"train", "val", "test"} def test_split_fractions_roughly_match_request() -> None: samples = generate_samples(n_per_scenario=200, seed=0, val_frac=0.2, test_frac=0.1) counts = Counter(s.split for s in samples) total = len(samples) train_frac = counts["train"] / total val_frac = counts["val"] / total test_frac = counts["test"] / total assert abs(train_frac - 0.7) < 0.05 assert abs(val_frac - 0.2) < 0.05 assert abs(test_frac - 0.1) < 0.05 def test_invalid_split_fractions_rejected() -> None: with pytest.raises(ValueError): generate_samples(n_per_scenario=4, val_frac=-0.1) with pytest.raises(ValueError): generate_samples(n_per_scenario=4, val_frac=0.6, test_frac=0.5) # --------------------------------------------------------------------------- # Bounds / coverage # --------------------------------------------------------------------------- def test_continuous_design_vars_are_within_bounds() -> None: samples = generate_samples(n_per_scenario=100, seed=11) for name, lo, hi in _CONTINUOUS_DESIGN_BOUNDS: values = [getattr(s.design, name) for s in samples] assert min(values) >= lo - 1e-9, name assert max(values) <= hi + 1e-9, name def test_grouser_count_is_integer_in_bounds() -> None: samples = generate_samples(n_per_scenario=100, seed=11) lo, hi = _GROUSER_COUNT_BOUNDS for s in samples: assert isinstance(s.design.grouser_count, int) assert lo <= s.design.grouser_count <= hi def test_soil_parameters_within_bounds() -> None: samples = generate_samples(n_per_scenario=100, seed=11) for col, (lo, hi) in _SOIL_BOUNDS.items(): attr = col[len("soil_") :] values = [getattr(s.soil, attr) for s in samples] assert min(values) >= lo - 1e-9, col assert max(values) <= hi + 1e-9, col def test_scenario_perturbation_stays_within_family_ranges() -> None: samples = generate_samples(n_per_scenario=50, seed=11) for s in samples: fam = FAMILIES[s.scenario_family] assert fam.latitude_range_deg[0] - 1e-9 <= s.scenario.latitude_deg assert s.scenario.latitude_deg <= fam.latitude_range_deg[1] + 1e-9 assert fam.mission_duration_range_days[0] - 1e-9 <= s.scenario.mission_duration_earth_days assert s.scenario.mission_duration_earth_days <= fam.mission_duration_range_days[1] + 1e-9 assert fam.max_slope_range_deg[0] - 1e-9 <= s.scenario.max_slope_deg assert s.scenario.max_slope_deg <= fam.max_slope_range_deg[1] + 1e-9 assert s.scenario.terrain_class == fam.terrain_class assert s.scenario.soil_simulant == fam.soil_simulant assert s.scenario.sun_geometry == fam.sun_geometry def test_lhs_covers_design_space_broadly() -> None: """A crude coverage sanity check: with 400 samples, the min/max of each continuous column should cover at least 80% of the bound range.""" samples = generate_samples(n_per_scenario=100, seed=3) for name, lo, hi in _CONTINUOUS_DESIGN_BOUNDS: values = [getattr(s.design, name) for s in samples] span = hi - lo realised = max(values) - min(values) assert realised / span > 0.8, f"{name}: coverage {realised / span:.2%}" # --------------------------------------------------------------------------- # Sample indexing # --------------------------------------------------------------------------- def test_sample_indices_are_dense_and_ordered() -> None: samples = generate_samples(n_per_scenario=6, seed=0) assert [s.sample_index for s in samples] == list(range(len(samples)))