roverdevkit / tests /test_surrogate_sampling.py
jjreif's picture
Deploy roverdevkit @ 2676a67
b3d14e3
Raw
History Blame Contribute Delete
7.58 kB
"""Unit tests for the stratified LHS sampler (initial baseline-surrogate schema)."""
from __future__ import annotations
from collections import Counter
import pytest
from roverdevkit.schema import DesignVector, MissionScenario
from roverdevkit.surrogate.sampling import (
_CONTINUOUS_DESIGN_BOUNDS,
_GROUSER_COUNT_BOUNDS,
_SOIL_BOUNDS,
FAMILIES,
LHSSample,
generate_samples,
)
# ---------------------------------------------------------------------------
# Basic shape / typing
# ---------------------------------------------------------------------------
def test_generate_samples_total_count_matches_contract() -> None:
    """The sampler returns ``n_per_scenario`` draws for every known family."""
    per_family = 8
    produced = generate_samples(n_per_scenario=per_family, seed=0)
    # One batch of ``n_per_scenario`` draws is generated per scenario family.
    expected_total = per_family * len(FAMILIES)
    assert len(produced) == expected_total
def test_generate_samples_subset_of_families() -> None:
    """Passing ``scenario_names`` restricts sampling to exactly those families."""
    requested = ["equatorial_mare_traverse", "polar_prospecting"]
    per_family = 4
    drawn = generate_samples(
        n_per_scenario=per_family,
        seed=0,
        scenario_names=requested,
    )
    assert len(drawn) == per_family * len(requested)
    assert {sample.scenario_family for sample in drawn} == {
        "equatorial_mare_traverse",
        "polar_prospecting",
    }
def test_generate_samples_rejects_unknown_family() -> None:
    """An unrecognised family name is reported as a ``KeyError``."""
    bogus_names = ["no_such_scenario"]
    with pytest.raises(KeyError, match="unknown scenario family"):
        generate_samples(n_per_scenario=2, scenario_names=bogus_names)
def test_generate_samples_rejects_odd_n() -> None:
    """An odd ``n_per_scenario`` raises a ``ValueError`` whose message mentions "even"."""
    odd_count = 7
    with pytest.raises(ValueError, match="even"):
        generate_samples(n_per_scenario=odd_count)
def test_generate_samples_rejects_nonpositive_n() -> None:
    """A zero ``n_per_scenario`` raises a ``ValueError`` whose message mentions "positive"."""
    zero_count = 0
    with pytest.raises(ValueError, match="positive"):
        generate_samples(n_per_scenario=zero_count)
def test_sample_objects_are_typed() -> None:
    """Each sample, and its design and scenario payloads, has the expected type."""
    for sample in generate_samples(n_per_scenario=2, seed=0):
        assert isinstance(sample, LHSSample)
        assert isinstance(sample.design, DesignVector)
        assert isinstance(sample.scenario, MissionScenario)
# ---------------------------------------------------------------------------
# Determinism
# ---------------------------------------------------------------------------
def test_generate_samples_is_deterministic_for_same_seed() -> None:
    """Two calls with the same seed reproduce every checked field of every sample.

    ``zip(..., strict=True)`` additionally fails if the two runs return a
    different number of samples.
    """
    s1 = generate_samples(n_per_scenario=8, seed=123)
    s2 = generate_samples(n_per_scenario=8, seed=123)
    for a, b in zip(s1, s2, strict=True):
        assert a.design == b.design
        assert a.scenario == b.scenario
        assert a.soil == b.soil
        assert a.split == b.split
        assert a.sample_index == b.sample_index
        # The family label and the wheel stratum are also part of a sample's
        # identity, so a non-deterministic assignment of either is a regression too.
        assert a.scenario_family == b.scenario_family
        assert a.stratum_id == b.stratum_id
def test_different_seeds_produce_different_draws() -> None:
    """Changing the seed changes at least one design vector."""
    first_run = generate_samples(n_per_scenario=8, seed=123)
    second_run = generate_samples(n_per_scenario=8, seed=456)
    # Build the full list so the strict zip still checks that the lengths match.
    changed = [x.design != y.design for x, y in zip(first_run, second_run, strict=True)]
    assert any(changed)
# ---------------------------------------------------------------------------
# Stratification
# ---------------------------------------------------------------------------
def test_wheel_strata_are_exact_50_50() -> None:
    """Each scenario family is split exactly in half between 4- and 6-wheel designs."""
    n_per_scenario = 20
    half = n_per_scenario // 2
    samples = generate_samples(n_per_scenario=n_per_scenario, seed=7)
    for family_name in FAMILIES:
        fam = [s for s in samples if s.scenario_family == family_name]
        counts = Counter(s.design.n_wheels for s in fam)
        assert counts[4] == half, f"{family_name}: 4-wheel count {counts[4]} != {half}"
        assert counts[6] == half, f"{family_name}: 6-wheel count {counts[6]} != {half}"
        # Counter lookups of missing keys return 0 without inserting them, so
        # this catches any stray wheel count (e.g. 5) that the two tallies
        # above would otherwise let through.
        assert set(counts) == {4, 6}, f"{family_name}: unexpected wheel counts {dict(counts)}"
def test_stratum_id_matches_n_wheels() -> None:
    """``stratum_id`` is 0 for 4-wheel designs and 1 for 6-wheel designs."""
    # An explicit table, rather than ``0 if n_wheels == 4 else 1``, makes any
    # wheel count other than 4 or 6 fail loudly instead of being treated as
    # stratum 1.
    stratum_for_wheels = {4: 0, 6: 1}
    samples = generate_samples(n_per_scenario=8, seed=0)
    for s in samples:
        n_wheels = s.design.n_wheels
        assert n_wheels in stratum_for_wheels, f"unexpected n_wheels {n_wheels!r}"
        assert s.stratum_id == stratum_for_wheels[n_wheels]
# ---------------------------------------------------------------------------
# Splits
# ---------------------------------------------------------------------------
def test_split_labels_are_valid() -> None:
    """Only the three known split names ever appear on a sample."""
    permitted = {"train", "val", "test"}
    for sample in generate_samples(n_per_scenario=10, seed=0):
        assert sample.split in permitted
def test_split_fractions_roughly_match_request() -> None:
    """Each realised split proportion lands within 0.05 of the requested one."""
    drawn = generate_samples(n_per_scenario=200, seed=0, val_frac=0.2, test_frac=0.1)
    tally = Counter(sample.split for sample in drawn)
    n_total = len(drawn)
    # The train share is whatever is left after the val and test fractions.
    targets = (("train", 0.7), ("val", 0.2), ("test", 0.1))
    for label, target in targets:
        realised = tally[label] / n_total
        assert abs(realised - target) < 0.05
def test_invalid_split_fractions_rejected() -> None:
    """A negative val fraction, or val + test fractions exceeding 1, raise ``ValueError``."""
    bad_settings = (
        {"val_frac": -0.1},
        {"val_frac": 0.6, "test_frac": 0.5},
    )
    for overrides in bad_settings:
        with pytest.raises(ValueError):
            generate_samples(n_per_scenario=4, **overrides)
# ---------------------------------------------------------------------------
# Bounds / coverage
# ---------------------------------------------------------------------------
def test_continuous_design_vars_are_within_bounds() -> None:
    """Every continuous design variable respects its declared [lo, hi] range."""
    tol = 1e-9
    drawn = generate_samples(n_per_scenario=100, seed=11)
    for var_name, lower, upper in _CONTINUOUS_DESIGN_BOUNDS:
        column = [getattr(sample.design, var_name) for sample in drawn]
        smallest, largest = min(column), max(column)
        assert smallest >= lower - tol, var_name
        assert largest <= upper + tol, var_name
def test_grouser_count_is_integer_in_bounds() -> None:
    """Grouser counts are plain ``int`` values inside the declared bounds."""
    drawn = generate_samples(n_per_scenario=100, seed=11)
    min_count, max_count = _GROUSER_COUNT_BOUNDS
    for count in (sample.design.grouser_count for sample in drawn):
        assert isinstance(count, int)
        assert min_count <= count <= max_count
def test_soil_parameters_within_bounds() -> None:
    """Every sampled soil parameter respects its declared [lo, hi] range.

    Keys of ``_SOIL_BOUNDS`` are expected to be ``soil_``-prefixed column names.
    The prefix is stripped to obtain the attribute name on ``sample.soil``.
    """
    prefix = "soil_"
    samples = generate_samples(n_per_scenario=100, seed=11)
    for col, (lo, hi) in _SOIL_BOUNDS.items():
        # Fail with a clear message on a malformed key, rather than slicing
        # off the wrong characters and surfacing a confusing AttributeError.
        assert col.startswith(prefix), f"soil bound key {col!r} lacks {prefix!r} prefix"
        attr = col.removeprefix(prefix)
        values = [getattr(s.soil, attr) for s in samples]
        assert min(values) >= lo - 1e-9, col
        assert max(values) <= hi + 1e-9, col
def test_scenario_perturbation_stays_within_family_ranges() -> None:
    """Perturbed scenarios stay inside their family's ranges and keep its categorical fields."""
    eps = 1e-9
    # Pairs of (scenario attribute, family range attribute).
    ranged_fields = (
        ("latitude_deg", "latitude_range_deg"),
        ("mission_duration_earth_days", "mission_duration_range_days"),
        ("max_slope_deg", "max_slope_range_deg"),
    )
    # Attributes that must be copied verbatim from the family.
    fixed_fields = ("terrain_class", "soil_simulant", "sun_geometry")
    for sample in generate_samples(n_per_scenario=50, seed=11):
        family = FAMILIES[sample.scenario_family]
        scenario = sample.scenario
        for value_attr, range_attr in ranged_fields:
            bounds = getattr(family, range_attr)
            value = getattr(scenario, value_attr)
            assert bounds[0] - eps <= value
            assert value <= bounds[1] + eps
        for attr in fixed_fields:
            assert getattr(scenario, attr) == getattr(family, attr)
def test_lhs_covers_design_space_broadly() -> None:
    """Crude coverage sanity check on the continuous design variables.

    With 100 draws per scenario family, the observed min-to-max spread of each
    continuous column must exceed 80% of that column's bound width.
    """
    drawn = generate_samples(n_per_scenario=100, seed=3)
    for var_name, lower, upper in _CONTINUOUS_DESIGN_BOUNDS:
        column = [getattr(sample.design, var_name) for sample in drawn]
        covered = (max(column) - min(column)) / (upper - lower)
        assert covered > 0.8, f"{var_name}: coverage {covered:.2%}"
# ---------------------------------------------------------------------------
# Sample indexing
# ---------------------------------------------------------------------------
def test_sample_indices_are_dense_and_ordered() -> None:
    """``sample_index`` runs 0, 1, 2, ... in list order with no gaps."""
    drawn = generate_samples(n_per_scenario=6, seed=0)
    for position, sample in enumerate(drawn):
        assert sample.sample_index == position