Spaces:
Sleeping
Sleeping
| """ | |
| Alignment Tests (env v2) — The critical tests v1 didn't have. | |
| These check that step rewards and final grades are consistent (no | |
| step-reward / grader misalignment) and that hidden objectives are | |
| properly hidden in observations. | |
| """ | |
| import sys, os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from server.environment import GeoPolicyEnv | |
| from models.action import GeoAction | |
def _fresh_env(task_id="task3"):
    """Build a GeoPolicyEnv and reset it for ``task_id`` with the fixed seed 42.

    Returns the environment object itself (not the value returned by
    ``reset``), so tests can inspect its state directly.
    """
    fresh = GeoPolicyEnv()
    fresh.reset(task_id=task_id, seed=42)
    return fresh
| # ============================================================ | |
| # Alignment tests | |
| # ============================================================ | |
def test_step_and_final_share_components():
    """Step reward and final grade must report the same rubric component names.

    This is the core invariant guarding against step-reward / grader
    misalignment: both results are produced from the same rankings and their
    ``"components"`` mappings are compared by key set.
    """
    env = _fresh_env("task3")
    rubric = env.task_rubric
    rankings = env.get_rankings()

    step_result = rubric.step_reward("aria", env.countries, {}, rankings)
    final_result = rubric.final_grade("aria", env.countries, rankings)

    step_names = set(step_result["components"].keys())
    final_names = set(final_result["components"].keys())
    assert step_names == final_names
    print("PASS: step and final share component names")
def test_step_and_final_use_same_weights():
    """Step reward and final grade must both cover exactly the rubric's weight keys.

    The rubric's ``weights`` mapping is taken as the reference; the component
    names reported by ``step_reward`` and by ``final_grade`` must each equal
    its key set. Only the key sets are compared here, not the weight values.
    """
    env = _fresh_env("task3")
    rubric = env.task_rubric
    weight_names = set(rubric.weights.keys())

    step_result = rubric.step_reward("aria", env.countries, {}, env.get_rankings())
    final_result = rubric.final_grade("aria", env.countries, env.get_rankings())

    step_names = set(step_result["components"].keys())
    final_names = set(final_result["components"].keys())
    assert step_names == weight_names
    assert final_names == weight_names
    print("PASS: step and final use the same per-task weights")
def test_grade_country_returns_blended_total():
    """``grade_country()`` must match ``task_rubric.final_grade()['total']``.

    The rubric total is rounded to 4 decimal places before the comparison
    (presumably mirroring rounding done inside ``grade_country`` — confirm
    against the environment), and the two values must agree within 1e-6.
    """
    env = _fresh_env("task3")
    rankings = env.get_rankings()
    rubric_result = env.task_rubric.final_grade("aria", env.countries, rankings)
    env_grade = env.grade_country("aria")

    expected_total = round(rubric_result["total"], 4)
    difference = abs(env_grade - expected_total)
    assert difference < 1e-6
    print("PASS: grade_country == task_rubric.final_grade()['total']")
| # ============================================================ | |
| # Observation masking | |
| # ============================================================ | |
def test_observation_includes_own_objective():
    """A country's own observation must carry its hidden objective.

    The id, name and description fields of the objective must all be
    populated (non-``None``) on the observation returned for "aria".
    """
    env = _fresh_env("task3")
    own_obs = env.get_observation("aria")

    objective_id = own_obs.hidden_objective_id
    assert objective_id is not None
    objective_name = own_obs.hidden_objective_name
    assert objective_name is not None
    objective_description = own_obs.hidden_objective_description
    assert objective_description is not None
    print(f"PASS: own observation includes objective ({own_obs.hidden_objective_id})")
def test_observation_masks_others_objectives():
    """Other countries' public info must not expose any hidden-objective field.

    For every entry in ``obs.other_countries`` this checks that:

    * none of the hidden-objective attributes that appear on a country's own
      observation (id, name, description), nor a bare ``hidden_objective``,
      is present as an attribute; and
    * no key of the serialised form (``info.dict()``) starts with
      ``hidden_objective``.

    The test also fails if there are no other countries at all, because an
    empty loop would otherwise pass without verifying anything.
    """
    env = _fresh_env("task3")
    obs = env.get_observation("aria")

    # Guard against a vacuous pass: with nothing to iterate over, the loop
    # below would prove nothing about masking.
    assert obs.other_countries, "observation lists no other countries to check"

    leak_prefix = "hidden_objective"
    forbidden_attrs = (
        "hidden_objective",
        "hidden_objective_id",
        "hidden_objective_name",
        "hidden_objective_description",
    )
    for other_id, info in obs.other_countries.items():
        # The public schema must not carry any of the objective attributes.
        for attr in forbidden_attrs:
            assert not hasattr(info, attr), (
                f"Other country {other_id} leaks {attr}"
            )
        # Catch any other objective-related key in the serialised form.
        leaked_keys = sorted(
            str(key) for key in info.dict() if str(key).startswith(leak_prefix)
        )
        assert not leaked_keys, (
            f"Other country {other_id} leaks {leaked_keys}"
        )
    print("PASS: others' observations do not leak hidden objectives")
def test_task1_disables_hidden_objectives():
    """Task 1 (bilateral) must not assign hidden objectives.

    After resetting for "task1" with seed 0, every active country's
    ``hidden_objective`` must be ``None``, and the observation for the first
    active country must report no ``hidden_objective_id``.
    """
    env = GeoPolicyEnv()
    env.reset(task_id="task1", seed=0)

    for country_id in env.active_country_ids:
        country = env.countries[country_id]
        assert country.hidden_objective is None

    first_country_id = env.active_country_ids[0]
    first_obs = env.get_observation(first_country_id)
    assert first_obs.hidden_objective_id is None
    print("PASS: task1 has no hidden objectives assigned")
def test_task2_assigns_hidden_objectives():
    """Task 2 must give every active country its own distinct hidden objective.

    After resetting for "task2" with seed 0, no active country may have a
    ``None`` objective and no two countries may share the same one.
    """
    env = GeoPolicyEnv()
    env.reset(task_id="task2", seed=0)

    assigned = []
    for country_id in env.active_country_ids:
        assigned.append(env.countries[country_id].hidden_objective)

    assert not any(objective is None for objective in assigned)
    distinct = set(assigned)
    assert len(distinct) == len(assigned)  # all unique
    print(f"PASS: task2 assigns {len(assigned)} unique objectives")
def test_task3_assigns_hidden_objectives():
    """Task 3 must give every active country its own distinct hidden objective.

    After resetting for "task3" with seed 0, no active country may have a
    ``None`` objective and no two countries may share the same one.
    """
    env = GeoPolicyEnv()
    env.reset(task_id="task3", seed=0)

    assigned = []
    for country_id in env.active_country_ids:
        assigned.append(env.countries[country_id].hidden_objective)

    assert not any(objective is None for objective in assigned)
    distinct = set(assigned)
    assert len(distinct) == len(assigned)  # all unique
    print(f"PASS: task3 assigns {len(assigned)} unique objectives")
| # ============================================================ | |
| # Snapshot/restore preserves objectives | |
| # ============================================================ | |
def test_snapshot_preserves_hidden_objectives():
    """Snapshot/restore must bring back the hidden objective assignments.

    The objectives are recorded, a snapshot is taken, every active country's
    objective is overwritten with "KINGMAKER", and the snapshot is restored.
    The restored assignments must equal the recorded ones, and the task
    rubric must still produce a step reward containing a "total" entry.
    """
    env = _fresh_env("task3")

    def current_objectives():
        # Map each active country id to its current hidden objective.
        mapping = {}
        for country_id in env.active_country_ids:
            mapping[country_id] = env.countries[country_id].hidden_objective
        return mapping

    before = current_objectives()
    saved_state = env.snapshot()

    # Clobber every assignment so a restore that does nothing is detected.
    for country_id in env.active_country_ids:
        env.countries[country_id].hidden_objective = "KINGMAKER"

    env.restore(saved_state)
    after = current_objectives()
    assert after == before

    # The rubric should still be usable after the restore.
    reward = env.task_rubric.step_reward(
        "aria", env.countries, {}, env.get_rankings()
    )
    assert "total" in reward
    print("PASS: snapshot/restore preserves hidden objectives and rubric")
| # ============================================================ | |
| # Reward contains components in metadata | |
| # ============================================================ | |
def test_step_metadata_includes_components():
    """``env.step()`` must attach ``reward_components`` to the observation metadata.

    After a single WAIT action from "aria", the metadata must contain a
    "reward_components" entry whose keys are exactly the five expected
    component names.
    """
    expected_components = {
        "economic", "diplomatic", "military", "stability", "hidden"
    }
    env = _fresh_env("task3")
    wait_action = GeoAction(action_type="WAIT", source_country="aria")
    result = env.step(wait_action)

    assert "reward_components" in result.metadata
    reported = set(result.metadata["reward_components"].keys())
    assert reported == expected_components
    print("PASS: step metadata includes reward_components")
| # ============================================================ | |
| # Reward in valid range | |
| # ============================================================ | |
def test_step_reward_always_in_unit_interval():
    """Every step reward in a task3 episode of up to 12 turns must lie in [0, 1].

    Each turn, every active country submits a WAIT action in order. The reward
    on each returned observation is range-checked, and the episode stops as
    soon as an observation reports ``done``.
    """
    env = _fresh_env("task3")
    country_ids = list(env.active_country_ids)
    # With no countries, no step would ever run and the test would verify
    # nothing; fail with a clear message instead of an unbound-name error.
    assert country_ids, "task3 has no active countries to step"

    episode_over = False
    for _ in range(12):
        for country_id in country_ids:
            action = GeoAction(action_type="WAIT", source_country=country_id)
            obs = env.step(action)
            assert 0.0 <= obs.reward <= 1.0, f"reward {obs.reward} out of range"
            if obs.done:
                episode_over = True
                break
        if episode_over:
            break
    print("PASS: all step rewards in [0, 1] across full episode")
if __name__ == "__main__":
    # Minimal standalone runner: run each test, report failures, print a
    # summary, and signal failure to the caller through the exit status.
    tests = [
        test_step_and_final_share_components,
        test_step_and_final_use_same_weights,
        test_grade_country_returns_blended_total,
        test_observation_includes_own_objective,
        test_observation_masks_others_objectives,
        test_task1_disables_hidden_objectives,
        test_task2_assigns_hidden_objectives,
        test_task3_assigns_hidden_objectives,
        test_snapshot_preserves_hidden_objectives,
        test_step_metadata_includes_components,
        test_step_reward_always_in_unit_interval,
    ]
    passed = failed = 0
    for t in tests:
        try:
            t()
        except Exception as e:
            # Most asserts here carry no message, so include the exception
            # type to keep the failure line informative.
            print(f"FAIL: {t.__name__} — {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1
    print(f"\n{'='*50}\nResults: {passed} passed, {failed} failed out of {len(tests)}\n{'='*50}")
    # Exit non-zero on failure so shells and CI can detect a failing run.
    if failed:
        sys.exit(1)