"""Test script to validate all merged enhancements are properly parameterized.
Tests the following merged PRs:
- PR #2: Override handling (state pollution fix)
- PR #3: Ripeness UNKNOWN state
- PR #6: Parameter fallback with bundled defaults
- PR #4: RL training with SchedulingAlgorithm constraints
- PR #5: Shared reward helper
- PR #7: Output metadata tracking
"""
import sys
from datetime import date, datetime
from pathlib import Path
# Test result tracking
TESTS_PASSED = []
TESTS_FAILED = []


def log_test(name: str, passed: bool, details: str = ""):
    """Log test result."""
    if passed:
        TESTS_PASSED.append(name)
        print(f"[PASS] {name}")
        if details:
            print(f" {details}")
    else:
        TESTS_FAILED.append(name)
        print(f"[FAIL] {name}")
        if details:
            print(f" {details}")


def test_pr2_override_validation():
    """Test PR #2: Override validation preserves original list and tracks rejections."""
    from src.control.overrides import Override, OverrideType
    from src.core.algorithm import SchedulingAlgorithm
    from src.core.courtroom import Courtroom
    from src.data.case_generator import CaseGenerator
    from src.simulation.allocator import CourtroomAllocator
    from src.simulation.policies.readiness import ReadinessPolicy

    try:
        # Generate test cases
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(50)

        # Create test overrides (some valid, some invalid)
        test_overrides = [
            Override(
                override_id="test-1",
                override_type=OverrideType.PRIORITY,
                case_id=cases[0].case_id,
                judge_id="TEST-JUDGE",
                timestamp=datetime.now(),
                new_priority=0.95,
            ),
            Override(
                override_id="test-2",
                override_type=OverrideType.PRIORITY,
                case_id="INVALID-CASE-ID",  # Invalid case
                judge_id="TEST-JUDGE",
                timestamp=datetime.now(),
                new_priority=0.85,
            ),
        ]
        original_count = len(test_overrides)

        # Setup algorithm
        courtrooms = [Courtroom(courtroom_id=1, judge_id="J001", daily_capacity=50)]
        allocator = CourtroomAllocator(num_courtrooms=1, per_courtroom_capacity=50)
        policy = ReadinessPolicy()
        algorithm = SchedulingAlgorithm(policy=policy, allocator=allocator)

        # Run scheduling with overrides
        result = algorithm.schedule_day(
            cases=cases,
            courtrooms=courtrooms,
            current_date=date(2024, 1, 15),
            overrides=test_overrides,
        )

        # Verify original list unchanged
        assert len(test_overrides) == original_count, (
            "Original override list was mutated"
        )
        # Verify rejection tracking exists (even if empty for valid overrides)
        assert hasattr(result, "override_rejections"), "No override_rejections field"
        # Verify applied overrides tracked
        assert hasattr(result, "applied_overrides"), "No applied_overrides field"

        log_test(
            "PR #2: Override validation",
            True,
            f"Applied: {len(result.applied_overrides)}, Rejected: {len(result.override_rejections)}",
        )
        return True
    except Exception as e:
        log_test("PR #2: Override validation", False, str(e))
        return False


def test_pr2_flag_cleanup():
    """Test PR #2: Temporary case flags are cleared after scheduling."""
    from src.core.algorithm import SchedulingAlgorithm
    from src.core.courtroom import Courtroom
    from src.data.case_generator import CaseGenerator
    from src.simulation.allocator import CourtroomAllocator
    from src.simulation.policies.readiness import ReadinessPolicy

    try:
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(10)

        # Set priority override flag
        test_case = cases[0]
        test_case._priority_override = 0.99

        # Run scheduling
        courtrooms = [Courtroom(courtroom_id=1, judge_id="J001", daily_capacity=50)]
        allocator = CourtroomAllocator(num_courtrooms=1, per_courtroom_capacity=50)
        policy = ReadinessPolicy()
        algorithm = SchedulingAlgorithm(policy=policy, allocator=allocator)
        algorithm.schedule_day(cases, courtrooms, date(2024, 1, 15))

        # Verify flag cleared
        assert (
            not hasattr(test_case, "_priority_override")
            or test_case._priority_override is None
        ), "Priority override flag not cleared"

        log_test(
            "PR #2: Flag cleanup", True, "Temporary flags cleared after scheduling"
        )
        return True
    except Exception as e:
        log_test("PR #2: Flag cleanup", False, str(e))
        return False


def test_pr3_unknown_ripeness():
    """Test PR #3: UNKNOWN ripeness status exists and is used."""
    from src.core.ripeness import RipenessClassifier, RipenessStatus
    from src.data.case_generator import CaseGenerator

    try:
        # Verify UNKNOWN status exists
        assert hasattr(RipenessStatus, "UNKNOWN"), "RipenessStatus.UNKNOWN not found"

        # Create case with ambiguous ripeness
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(10)

        # Clear ripeness indicators to test UNKNOWN default
        test_case = cases[0]
        test_case.last_hearing_date = None
        test_case.service_status = None
        test_case.compliance_status = None

        # Classify ripeness
        ripeness = RipenessClassifier.classify(test_case, date(2024, 1, 15))

        # Should default to UNKNOWN when no evidence
        assert ripeness == RipenessStatus.UNKNOWN or not ripeness.is_ripe(), (
            "Ambiguous case did not get UNKNOWN or non-RIPE status"
        )

        log_test("PR #3: UNKNOWN ripeness", True, f"Status: {ripeness.value}")
        return True
    except Exception as e:
        log_test("PR #3: UNKNOWN ripeness", False, str(e))
        return False


def test_pr6_parameter_fallback():
    """Test PR #6: Parameter fallback with bundled defaults."""
    try:
        # Test that defaults directory exists
        defaults_dir = Path("scheduler/data/defaults")
        assert defaults_dir.exists(), f"Defaults directory not found: {defaults_dir}"

        # Check for expected default files
        expected_files = [
            "stage_transition_probs.csv",
            "stage_duration.csv",
            "adjournment_proxies.csv",
            "court_capacity_global.json",
            "stage_transition_entropy.csv",
            "case_type_summary.csv",
        ]
        for file in expected_files:
            file_path = defaults_dir / file
            assert file_path.exists(), f"Default file missing: {file}"

        log_test(
            "PR #6: Parameter fallback",
            True,
            f"Found {len(expected_files)} default parameter files",
        )
        return True
    except Exception as e:
        log_test("PR #6: Parameter fallback", False, str(e))
        return False


def test_pr4_rl_constraints():
    """Test PR #4: RL training uses SchedulingAlgorithm with constraints."""
    from rl.config import RLTrainingConfig
    from rl.training import RLTrainingEnvironment
    from src.data.case_generator import CaseGenerator

    try:
        # Create training environment
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(100)
        config = RLTrainingConfig(
            episodes=2,
            cases_per_episode=100,
            episode_length_days=10,
            courtrooms=2,
            daily_capacity_per_courtroom=50,
            enforce_min_gap=True,
            cap_daily_allocations=True,
            apply_judge_preferences=True,
        )
        env = RLTrainingEnvironment(
            cases=cases, start_date=date(2024, 1, 1), horizon_days=10, rl_config=config
        )

        # Verify SchedulingAlgorithm components exist
        assert hasattr(env, "algorithm"), (
            "No SchedulingAlgorithm in training environment"
        )
        assert hasattr(env, "courtrooms"), "No courtrooms in training environment"
        assert hasattr(env, "allocator"), "No allocator in training environment"
        assert hasattr(env, "policy"), "No policy in training environment"

        # Test step with agent decisions
        agent_decisions = {cases[0].case_id: 1, cases[1].case_id: 1}
        updated_cases, rewards, done = env.step(agent_decisions)
        assert rewards is not None, "No rewards returned from step"
        log_test(
            "PR #4: RL constraints",
            True,
            f"Environment has algorithm, courtrooms, allocator. Capacity enforced: {config.cap_daily_allocations}",
        )
        return True
    except Exception as e:
        log_test("PR #4: RL constraints", False, str(e))
        return False


def test_pr5_shared_rewards():
    """Test PR #5: Shared reward helper exists and is used."""
    from rl.rewards import EpisodeRewardHelper
    from rl.training import RLTrainingEnvironment
    from src.data.case_generator import CaseGenerator

    try:
        # Verify EpisodeRewardHelper exists
        helper = EpisodeRewardHelper(total_cases=100)
        assert hasattr(helper, "compute_case_reward"), "No compute_case_reward method"

        # Verify training environment uses it
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(50)
        env = RLTrainingEnvironment(cases, date(2024, 1, 1), 10)
        assert hasattr(env, "reward_helper"), (
            "Training environment doesn't use reward_helper"
        )
        assert isinstance(env.reward_helper, EpisodeRewardHelper), (
            "reward_helper is not EpisodeRewardHelper instance"
        )

        # Test reward computation
        test_case = cases[0]
        reward = env.reward_helper.compute_case_reward(
            case=test_case,
            was_scheduled=True,
            hearing_outcome="PROGRESS",
            current_date=date(2024, 1, 15),
            previous_gap_days=30,
        )
        assert isinstance(reward, float), "Reward is not a float"

        log_test(
            "PR #5: Shared rewards",
            True,
            f"Helper integrated, sample reward: {reward:.2f}",
        )
        return True
    except Exception as e:
        log_test("PR #5: Shared rewards", False, str(e))
        return False


def test_pr7_metadata_tracking():
    """Test PR #7: Output metadata tracking."""
    from src.utils.output_manager import OutputManager

    try:
        # Create output manager
        output = OutputManager(run_id="test_run")
        output.create_structure()

        # Verify metadata methods exist
        assert hasattr(output, "record_eda_metadata"), "No record_eda_metadata method"
        assert hasattr(output, "save_training_stats"), "No save_training_stats method"
        assert hasattr(output, "save_evaluation_stats"), (
            "No save_evaluation_stats method"
        )
        assert hasattr(output, "record_simulation_kpis"), (
            "No record_simulation_kpis method"
        )

        # Verify run_record file created
        assert output.run_record_file.exists(), "run_record.json not created"

        # Test metadata recording
        output.record_eda_metadata(
            version="test_v1",
            used_cached=False,
            params_path=Path("test_params"),
            figures_path=Path("test_figures"),
        )

        # Verify metadata was written
        import json

        with open(output.run_record_file, "r") as f:
            record = json.load(f)
        assert "sections" in record, "No sections in run_record"
        assert "eda" in record["sections"], "EDA metadata not recorded"

        log_test(
            "PR #7: Metadata tracking",
            True,
            f"Run record created with {len(record['sections'])} sections",
        )
        return True
    except Exception as e:
        log_test("PR #7: Metadata tracking", False, str(e))
        return False


def run_all_tests():
    """Run all enhancement tests."""
    print("=" * 60)
    print("Testing Merged Enhancements")
    print("=" * 60)
    print()

    # PR #2 tests
    print("PR #2: Override Handling Refactor")
    print("-" * 40)
    test_pr2_override_validation()
    test_pr2_flag_cleanup()
    print()

    # PR #3 tests
    print("PR #3: Ripeness UNKNOWN State")
    print("-" * 40)
    test_pr3_unknown_ripeness()
    print()

    # PR #6 tests
    print("PR #6: Parameter Fallback")
    print("-" * 40)
    test_pr6_parameter_fallback()
    print()

    # PR #4 tests
    print("PR #4: RL Training Alignment")
    print("-" * 40)
    test_pr4_rl_constraints()
    print()

    # PR #5 tests
    print("PR #5: Shared Reward Helper")
    print("-" * 40)
    test_pr5_shared_rewards()
    print()

    # PR #7 tests
    print("PR #7: Output Metadata Tracking")
    print("-" * 40)
    test_pr7_metadata_tracking()
    print()

    # Summary
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)
    print(f"Passed: {len(TESTS_PASSED)}")
    print(f"Failed: {len(TESTS_FAILED)}")
    print()
    if TESTS_FAILED:
        print("Failed tests:")
        for test in TESTS_FAILED:
            print(f" - {test}")
        return 1
    else:
        print("All tests passed!")
        return 0


if __name__ == "__main__":
    sys.exit(run_all_tests())