prefero / scripts /test_cv.py
Wil2200's picture
Add HTML report, cross-validation, admin analytics, auto-adaptive data, Slowbro pill fix
c62aef1
"""Cross-validation module tests.
Run from project root:
python3 scripts/test_cv.py
"""
from __future__ import annotations
import sys
import traceback
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
import numpy as np
import torch
from dce_analyzer.simulate import generate_simulated_dce
from dce_analyzer.config import FullModelSpec, VariableSpec
from dce_analyzer.cross_validation import cross_validate, CrossValidationResult
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_results: list[tuple[str, bool, str]] = []
def _run(name: str, fn):
"""Run *fn* and record PASS / FAIL."""
try:
fn()
_results.append((name, True, ""))
print(f" PASS {name}")
except Exception as exc:
msg = f"{exc.__class__.__name__}: {exc}"
_results.append((name, False, msg))
print(f" FAIL {name}")
traceback.print_exc()
print()
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
sim = generate_simulated_dce(n_individuals=30, n_tasks=4, n_alts=3, seed=42)
DF = sim.data

# All three attributes enter with fixed (non-random) coefficients.
VARS_FIXED = [
    VariableSpec(name=attr, column=attr, distribution="fixed")
    for attr in ("price", "time", "comfort")
]

SPEC_CL = FullModelSpec(
    id_col="respondent_id",
    task_col="task_id",
    alt_col="alternative",
    choice_col="choice",
    variables=VARS_FIXED,
    model_type="conditional",
    maxiter=100,
)

CPU = torch.device("cpu")

# Filled in by the conditional-logit CV test so later tests can reuse it
# instead of re-running cross-validation.
_cl_result: CrossValidationResult | None = None
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_kfold_split_preserves_all_individuals():
    """1. K-fold split preserves all individuals.

    Shuffles the respondent IDs with a fixed seed, cuts them into five folds
    with ``np.array_split`` and checks that the folds jointly cover every ID.
    """
    all_ids = DF["respondent_id"].unique()
    shuffled = all_ids.copy()
    np.random.default_rng(42).shuffle(shuffled)
    n_folds = 5
    # Union of all folds should equal the full set of respondents.
    covered = set().union(
        *(chunk.tolist() for chunk in np.array_split(shuffled, n_folds))
    )
    assert covered == set(all_ids), (
        f"Union of folds ({len(covered)}) != all IDs ({len(all_ids)})"
    )


_run("1. K-fold split preserves all individuals", test_kfold_split_preserves_all_individuals)
def test_no_individual_in_both_train_and_test():
    """2. No individual appears in both train and test for any fold.

    Mirrors how k-fold CV assembles its sets: the test set is one fold and
    the training set is the union of the *other* folds.  Building the
    training set this way, rather than as the set complement of the test
    IDs (which is disjoint by construction and so can never fail), means an
    individual placed in more than one fold is actually detected.  Each
    train/test pair must also jointly cover every individual.
    """
    unique_ids = DF["respondent_id"].unique()
    rng = np.random.default_rng(42)
    ids_copy = unique_ids.copy()
    rng.shuffle(ids_copy)
    k = 5
    folds = [set(fold.tolist()) for fold in np.array_split(ids_copy, k)]
    all_ids = set(unique_ids.tolist())
    for fold_idx, test_ids in enumerate(folds):
        # Training set = every other fold, exactly as CV would build it.
        train_ids = set().union(
            *(ids for other_idx, ids in enumerate(folds) if other_idx != fold_idx)
        )
        overlap = test_ids & train_ids
        assert not overlap, f"Fold {fold_idx}: overlap={overlap}"
        covered = train_ids | test_ids
        assert covered == all_ids, (
            f"Fold {fold_idx}: train+test cover {len(covered)} "
            f"of {len(all_ids)} IDs"
        )


_run("2. No individual in both train and test", test_no_individual_in_both_train_and_test)
def test_cv_conditional_logit():
    """3. CV with Conditional Logit (3-fold).

    Runs 3-fold cross-validation of the fixed-coefficient conditional logit
    and checks the returned summary.  Only after every check passes is the
    result stored in the module-level ``_cl_result``.  Test 5 treats a
    non-None ``_cl_result`` as "test 3 passed", so a result that failed
    validation must not leak into it.
    """
    global _cl_result
    result = cross_validate(DF, SPEC_CL, k=3, seed=42, device=CPU)
    assert isinstance(result, CrossValidationResult), "Wrong return type"
    assert result.k == 3, f"Expected k=3, got k={result.k}"
    assert len(result.fold_results) == 3, (
        f"Expected 3 fold results, got {len(result.fold_results)}"
    )
    assert result.mean_test_ll < 0, (
        f"Expected negative mean test LL, got {result.mean_test_ll}"
    )
    assert result.model_type == "conditional"
    assert result.total_runtime > 0
    # Publish for later tests only once the result has been validated.
    _cl_result = result


_run("3. CV with Conditional Logit (3-fold)", test_cv_conditional_logit)
def test_cv_mixed_logit():
    """4. CV with Mixed Logit (3-fold).

    Price and time get normally distributed random coefficients while
    comfort stays fixed.  A small draw count and iteration cap keep the run
    short.  Checks the return type, fold count and reported model type.
    """
    coef_distributions = {"price": "normal", "time": "normal", "comfort": "fixed"}
    spec_mxl = FullModelSpec(
        id_col="respondent_id",
        task_col="task_id",
        alt_col="alternative",
        choice_col="choice",
        variables=[
            VariableSpec(name=attr, column=attr, distribution=dist)
            for attr, dist in coef_distributions.items()
        ],
        model_type="mixed",
        n_draws=50,
        maxiter=50,
    )
    cv_out = cross_validate(DF, spec_mxl, k=3, seed=42, device=CPU)
    assert isinstance(cv_out, CrossValidationResult)
    assert cv_out.k == 3
    assert len(cv_out.fold_results) == 3
    assert cv_out.model_type == "mixed"


_run("4. CV with Mixed Logit (3-fold)", test_cv_mixed_logit)
def test_hit_rate_bounds():
    """5. Hit rate is between 0 and 1.

    Reuses the conditional-logit result cached by test 3 and checks every
    per-fold hit rate as well as the mean hit rate.
    """
    cached = _cl_result
    assert cached is not None, "Test 3 must pass first (CL result needed)"
    for fold_res in cached.fold_results:
        rate = fold_res.hit_rate
        assert 0.0 <= rate <= 1.0, f"Fold {fold_res.fold}: hit_rate={rate}"
    mean_rate = cached.mean_hit_rate
    assert 0.0 <= mean_rate <= 1.0, f"mean_hit_rate={mean_rate}"


_run("5. Hit rate is between 0 and 1", test_hit_rate_bounds)
def test_k_greater_than_n_individuals_raises():
    """6. K > n_individuals raises an error.

    With only 10 simulated respondents, asking for 100 folds must raise a
    ``ValueError`` whose message mentions the offending ``k=100``.
    """
    df_small = generate_simulated_dce(
        n_individuals=10, n_tasks=4, n_alts=3, seed=99
    ).data
    spec_small = FullModelSpec(
        id_col="respondent_id",
        task_col="task_id",
        alt_col="alternative",
        choice_col="choice",
        variables=VARS_FIXED,
        model_type="conditional",
        maxiter=50,
    )
    caught = None
    try:
        cross_validate(df_small, spec_small, k=100, seed=42, device=CPU)
    except ValueError as err:
        caught = err
    assert caught is not None, "Expected ValueError when K > n_individuals"
    assert "k=100" in str(caught), f"Unexpected error message: {caught}"


_run("6. K > n_individuals raises error", test_k_greater_than_n_individuals_raises)
def test_progress_callback_called():
    """7. Progress callback is called K times.

    Records every callback invocation and checks there is exactly one per
    fold.  The invocations must arrive in fold order (0, 1, ...) and each
    must report the total fold count.
    """
    k = 3
    calls = []

    # Parameter names are kept as-is in case cross_validate passes them by keyword.
    def record_progress(fold_idx, k, status):
        calls.append((fold_idx, k, status))

    cross_validate(DF, SPEC_CL, k=k, seed=42, device=CPU, progress_callback=record_progress)
    assert len(calls) == k, f"Expected {k} callback calls, got {len(calls)}"
    for expected_idx, (got_idx, got_k, _status) in enumerate(calls):
        assert got_idx == expected_idx, f"Expected fold_idx={expected_idx}, got {got_idx}"
        assert got_k == k, f"Expected k={k}, got {got_k}"


_run("7. Progress callback is called K times", test_progress_callback_called)
# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
def _summary():
    """Print the pass/fail tally and list every failed test.

    Reads the module-level ``_results`` log populated by ``_run``.

    Returns:
        True when no recorded test failed (including when none were run),
        False otherwise.
    """
    total = len(_results)
    failures = [(name, msg) for name, ok, msg in _results if not ok]
    failed = len(failures)
    passed = total - failed
    rule = "=" * 60
    print(f"\n{rule}")
    print(f" {passed} passed, {failed} failed out of {total} tests")
    print(rule)
    if failures:
        print("\n Failed tests:")
        for name, msg in failures:
            print(f" - {name}: {msg}")
    else:
        print(" ALL TESTS PASSED")
    return not failures


if __name__ == "__main__":
    # Exit status 0 on full success, 1 if any test failed.
    sys.exit(0 if _summary() else 1)