| """Test script to validate all merged enhancements are properly parameterized. | |
| Tests the following merged PRs: | |
| - PR #2: Override handling (state pollution fix) | |
| - PR #3: Ripeness UNKNOWN state | |
| - PR #6: Parameter fallback with bundled defaults | |
| - PR #4: RL training with SchedulingAlgorithm constraints | |
| - PR #5: Shared reward helper | |
| - PR #7: Output metadata tracking | |
| """ | |
| import sys | |
| from datetime import date, datetime | |
| from pathlib import Path | |
| # Test configurations | |
| TESTS_PASSED = [] | |
| TESTS_FAILED = [] | |
| def log_test(name: str, passed: bool, details: str = ""): | |
| """Log test result.""" | |
| if passed: | |
| TESTS_PASSED.append(name) | |
| print(f"[PASS] {name}") | |
| if details: | |
| print(f" {details}") | |
| else: | |
| TESTS_FAILED.append(name) | |
| print(f"[FAIL] {name}") | |
| if details: | |
| print(f" {details}") | |
def test_pr2_override_validation():
    """Test PR #2: Override validation preserves original list and tracks rejections."""
    from src.control.overrides import Override, OverrideType
    from src.core.algorithm import SchedulingAlgorithm
    from src.core.courtroom import Courtroom
    from src.data.case_generator import CaseGenerator
    from src.simulation.allocator import CourtroomAllocator
    from src.simulation.policies.readiness import ReadinessPolicy

    try:
        # Generate test cases
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(50)

        # Create test overrides (some valid, some invalid)
        test_overrides = [
            Override(
                override_id="test-1",
                override_type=OverrideType.PRIORITY,
                case_id=cases[0].case_id,
                judge_id="TEST-JUDGE",
                timestamp=datetime.now(),
                new_priority=0.95,
            ),
            Override(
                override_id="test-2",
                override_type=OverrideType.PRIORITY,
                case_id="INVALID-CASE-ID",  # Invalid case
                judge_id="TEST-JUDGE",
                timestamp=datetime.now(),
                new_priority=0.85,
            ),
        ]
        original_count = len(test_overrides)

        # Setup algorithm
        courtrooms = [Courtroom(courtroom_id=1, judge_id="J001", daily_capacity=50)]
        allocator = CourtroomAllocator(num_courtrooms=1, per_courtroom_capacity=50)
        policy = ReadinessPolicy()
        algorithm = SchedulingAlgorithm(policy=policy, allocator=allocator)

        # Run scheduling with overrides
        result = algorithm.schedule_day(
            cases=cases,
            courtrooms=courtrooms,
            current_date=date(2024, 1, 15),
            overrides=test_overrides,
        )

        # Verify original list unchanged
        assert len(test_overrides) == original_count, (
            "Original override list was mutated"
        )
        # Verify rejection tracking exists (even if empty for valid overrides)
        assert hasattr(result, "override_rejections"), "No override_rejections field"
        # Verify applied overrides tracked
        assert hasattr(result, "applied_overrides"), "No applied_overrides field"

        log_test(
            "PR #2: Override validation",
            True,
            f"Applied: {len(result.applied_overrides)}, Rejected: {len(result.override_rejections)}",
        )
        return True
    except Exception as e:
        log_test("PR #2: Override validation", False, str(e))
        return False
def test_pr2_flag_cleanup():
    """Test PR #2: Temporary case flags are cleared after scheduling."""
    from src.core.algorithm import SchedulingAlgorithm
    from src.core.courtroom import Courtroom
    from src.data.case_generator import CaseGenerator
    from src.simulation.allocator import CourtroomAllocator
    from src.simulation.policies.readiness import ReadinessPolicy

    try:
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(10)

        # Set priority override flag
        test_case = cases[0]
        test_case._priority_override = 0.99

        # Run scheduling
        courtrooms = [Courtroom(courtroom_id=1, judge_id="J001", daily_capacity=50)]
        allocator = CourtroomAllocator(num_courtrooms=1, per_courtroom_capacity=50)
        policy = ReadinessPolicy()
        algorithm = SchedulingAlgorithm(policy=policy, allocator=allocator)
        algorithm.schedule_day(cases, courtrooms, date(2024, 1, 15))

        # Verify flag cleared
        assert (
            not hasattr(test_case, "_priority_override")
            or test_case._priority_override is None
        ), "Priority override flag not cleared"

        log_test(
            "PR #2: Flag cleanup", True, "Temporary flags cleared after scheduling"
        )
        return True
    except Exception as e:
        log_test("PR #2: Flag cleanup", False, str(e))
        return False
def test_pr3_unknown_ripeness():
    """Test PR #3: UNKNOWN ripeness status exists and is used."""
    from src.core.ripeness import RipenessClassifier, RipenessStatus
    from src.data.case_generator import CaseGenerator

    try:
        # Verify UNKNOWN status exists
        assert hasattr(RipenessStatus, "UNKNOWN"), "RipenessStatus.UNKNOWN not found"

        # Create case with ambiguous ripeness
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(10)

        # Clear ripeness indicators to test UNKNOWN default
        test_case = cases[0]
        test_case.last_hearing_date = None
        test_case.service_status = None
        test_case.compliance_status = None

        # Classify ripeness
        ripeness = RipenessClassifier.classify(test_case, date(2024, 1, 15))

        # Should default to UNKNOWN when no evidence
        assert ripeness == RipenessStatus.UNKNOWN or not ripeness.is_ripe(), (
            "Ambiguous case did not get UNKNOWN or non-RIPE status"
        )

        log_test("PR #3: UNKNOWN ripeness", True, f"Status: {ripeness.value}")
        return True
    except Exception as e:
        log_test("PR #3: UNKNOWN ripeness", False, str(e))
        return False
def test_pr6_parameter_fallback():
    """Test PR #6: Parameter fallback with bundled defaults."""
    try:
        # Test that defaults directory exists
        defaults_dir = Path("scheduler/data/defaults")
        assert defaults_dir.exists(), f"Defaults directory not found: {defaults_dir}"

        # Check for expected default files
        expected_files = [
            "stage_transition_probs.csv",
            "stage_duration.csv",
            "adjournment_proxies.csv",
            "court_capacity_global.json",
            "stage_transition_entropy.csv",
            "case_type_summary.csv",
        ]
        for file in expected_files:
            file_path = defaults_dir / file
            assert file_path.exists(), f"Default file missing: {file}"

        log_test(
            "PR #6: Parameter fallback",
            True,
            f"Found {len(expected_files)} default parameter files",
        )
        return True
    except Exception as e:
        log_test("PR #6: Parameter fallback", False, str(e))
        return False
def test_pr4_rl_constraints():
    """Test PR #4: RL training uses SchedulingAlgorithm with constraints."""
    from rl.config import RLTrainingConfig
    from rl.training import RLTrainingEnvironment
    from src.data.case_generator import CaseGenerator

    try:
        # Create training environment
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(100)

        config = RLTrainingConfig(
            episodes=2,
            cases_per_episode=100,
            episode_length_days=10,
            courtrooms=2,
            daily_capacity_per_courtroom=50,
            enforce_min_gap=True,
            cap_daily_allocations=True,
            apply_judge_preferences=True,
        )
        env = RLTrainingEnvironment(
            cases=cases, start_date=date(2024, 1, 1), horizon_days=10, rl_config=config
        )

        # Verify SchedulingAlgorithm components exist
        assert hasattr(env, "algorithm"), (
            "No SchedulingAlgorithm in training environment"
        )
        assert hasattr(env, "courtrooms"), "No courtrooms in training environment"
        assert hasattr(env, "allocator"), "No allocator in training environment"
        assert hasattr(env, "policy"), "No policy in training environment"

        # Test step with agent decisions
        agent_decisions = {cases[0].case_id: 1, cases[1].case_id: 1}
        updated_cases, rewards, done = env.step(agent_decisions)
        assert len(rewards) > 0, "No rewards returned from step"

        log_test(
            "PR #4: RL constraints",
            True,
            f"Environment has algorithm, courtrooms, allocator. Capacity enforced: {config.cap_daily_allocations}",
        )
        return True
    except Exception as e:
        log_test("PR #4: RL constraints", False, str(e))
        return False
def test_pr5_shared_rewards():
    """Test PR #5: Shared reward helper exists and is used."""
    from rl.rewards import EpisodeRewardHelper
    from rl.training import RLTrainingEnvironment
    from src.data.case_generator import CaseGenerator

    try:
        # Verify EpisodeRewardHelper exists
        helper = EpisodeRewardHelper(total_cases=100)
        assert hasattr(helper, "compute_case_reward"), "No compute_case_reward method"

        # Verify training environment uses it
        gen = CaseGenerator(start=date(2024, 1, 1), end=date(2024, 1, 10), seed=42)
        cases = gen.generate(50)
        env = RLTrainingEnvironment(cases, date(2024, 1, 1), 10)
        assert hasattr(env, "reward_helper"), (
            "Training environment doesn't use reward_helper"
        )
        assert isinstance(env.reward_helper, EpisodeRewardHelper), (
            "reward_helper is not an EpisodeRewardHelper instance"
        )

        # Test reward computation
        test_case = cases[0]
        reward = env.reward_helper.compute_case_reward(
            case=test_case,
            was_scheduled=True,
            hearing_outcome="PROGRESS",
            current_date=date(2024, 1, 15),
            previous_gap_days=30,
        )
        assert isinstance(reward, float), "Reward is not a float"

        log_test(
            "PR #5: Shared rewards",
            True,
            f"Helper integrated, sample reward: {reward:.2f}",
        )
        return True
    except Exception as e:
        log_test("PR #5: Shared rewards", False, str(e))
        return False
def test_pr7_metadata_tracking():
    """Test PR #7: Output metadata tracking."""
    from src.utils.output_manager import OutputManager

    try:
        # Create output manager
        output = OutputManager(run_id="test_run")
        output.create_structure()

        # Verify metadata methods exist
        assert hasattr(output, "record_eda_metadata"), "No record_eda_metadata method"
        assert hasattr(output, "save_training_stats"), "No save_training_stats method"
        assert hasattr(output, "save_evaluation_stats"), (
            "No save_evaluation_stats method"
        )
        assert hasattr(output, "record_simulation_kpis"), (
            "No record_simulation_kpis method"
        )

        # Verify run_record file created
        assert output.run_record_file.exists(), "run_record.json not created"

        # Test metadata recording
        output.record_eda_metadata(
            version="test_v1",
            used_cached=False,
            params_path=Path("test_params"),
            figures_path=Path("test_figures"),
        )

        # Verify metadata was written
        with open(output.run_record_file, "r") as f:
            record = json.load(f)
        assert "sections" in record, "No sections in run_record"
        assert "eda" in record["sections"], "EDA metadata not recorded"

        log_test(
            "PR #7: Metadata tracking",
            True,
            f"Run record created with {len(record['sections'])} sections",
        )
        return True
    except Exception as e:
        log_test("PR #7: Metadata tracking", False, str(e))
        return False
def run_all_tests():
    """Run all enhancement tests."""
    print("=" * 60)
    print("Testing Merged Enhancements")
    print("=" * 60)
    print()

    # PR #2 tests
    print("PR #2: Override Handling Refactor")
    print("-" * 40)
    test_pr2_override_validation()
    test_pr2_flag_cleanup()
    print()

    # PR #3 tests
    print("PR #3: Ripeness UNKNOWN State")
    print("-" * 40)
    test_pr3_unknown_ripeness()
    print()

    # PR #6 tests
    print("PR #6: Parameter Fallback")
    print("-" * 40)
    test_pr6_parameter_fallback()
    print()

    # PR #4 tests
    print("PR #4: RL Training Alignment")
    print("-" * 40)
    test_pr4_rl_constraints()
    print()

    # PR #5 tests
    print("PR #5: Shared Reward Helper")
    print("-" * 40)
    test_pr5_shared_rewards()
    print()

    # PR #7 tests
    print("PR #7: Output Metadata Tracking")
    print("-" * 40)
    test_pr7_metadata_tracking()
    print()

    # Summary
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)
    print(f"Passed: {len(TESTS_PASSED)}")
    print(f"Failed: {len(TESTS_FAILED)}")
    print()

    if TESTS_FAILED:
        print("Failed tests:")
        for test in TESTS_FAILED:
            print(f"  - {test}")
        return 1
    else:
        print("All tests passed!")
        return 0


if __name__ == "__main__":
    sys.exit(run_all_tests())