Spaces:
Running
Running
| import sys | |
| import unittest | |
| from pathlib import Path | |
# Put both the repository root and the ``adaptshield`` package directory on
# sys.path so the test imports below resolve without installing the package.
# Each root is inserted at index 0 only if absent; PACKAGE_ROOT is handled
# last, so when both are inserted it ends up ahead of REPO_ROOT.
REPO_ROOT = Path(__file__).resolve().parents[1]
PACKAGE_ROOT = REPO_ROOT / "adaptshield"
for _import_root in (REPO_ROOT, PACKAGE_ROOT):
    _import_root_str = str(_import_root)
    if _import_root_str not in sys.path:
        sys.path.insert(0, _import_root_str)
del _import_root, _import_root_str
| import __init__ as adaptshield | |
| import server.app as server_app | |
| import train as train_module | |
| from client import AdaptshieldEnv | |
| from models import AdaptShieldAction | |
| from server.adaptshield_environment import AdaptShieldEnvironment | |
| from server.grader import normalize_episode_score, _required_tool_fusion | |
class PackageRegressionTests(unittest.TestCase):
    """Smoke tests for the package's exported names and the server app object."""

    def test_package_import_exports_expected_symbols(self) -> None:
        """The package ``__all__`` lists the action, observation and client names."""
        exported_names = adaptshield.__all__
        for expected_name in (
            "AdaptShieldAction",
            "AdaptShieldObservation",
            "AdaptshieldEnv",
        ):
            self.assertIn(expected_name, exported_names)

    def test_server_app_imports_fastapi_instance(self) -> None:
        """``server.app.app`` is an object whose class is named ``FastAPI``."""
        app_class_name = server_app.app.__class__.__name__
        self.assertEqual(app_class_name, "FastAPI")
class EnvironmentRegressionTests(unittest.TestCase):
    """Regression tests for the environment, client payloads, grader and prompt bank.

    Several tests read the environment's private ``_turn_config`` dict to find
    the scripted answer for the current turn, so they are coupled to that
    internal attribute.
    """

    # Keys the tests require to be absent from every tool result.
    _HIDDEN_TOOL_FIELDS = ("verified", "evidence_type")

    @staticmethod
    def _current_turn_config(env: "AdaptShieldEnvironment") -> dict:
        """Return a shallow copy of ``env._turn_config`` ({} if missing or None)."""
        return dict(getattr(env, "_turn_config", {}) or {})

    @staticmethod
    def _submit_turn(
        env: "AdaptShieldEnvironment",
        *,
        threat_type: str,
        target: str,
        action: str,
        confidence: float = 0.9,
    ):
        """Play one full turn: a phase-1 assessment, then the phase-2 action.

        The phase-1 ``recommended_action`` and the phase-2 ``action`` use the
        same value, matching how every test here drives the environment.

        Returns:
            The observation produced by the phase-2 ``env.step`` call.
        """
        env.step(
            AdaptShieldAction(
                threat_type=threat_type,
                confidence=confidence,
                target_node=target,
                recommended_action=action,
            )
        )
        return env.step(AdaptShieldAction(action=action, target_node=target))

    def _play_scripted_turns(self, env: "AdaptShieldEnvironment", turns: int) -> list:
        """Play ``turns`` turns using the answers stored in the env's turn config.

        Missing config keys fall back to ``brute_force`` / ``auth_service`` /
        ``monitor``. Returns each turn's phase-2 observation, in order.
        """
        observations = []
        for _ in range(turns):
            config = self._current_turn_config(env)
            observations.append(
                self._submit_turn(
                    env,
                    threat_type=str(config.get("strategy", "brute_force")),
                    target=str(config.get("correct_target", "auth_service")),
                    action=str(config.get("correct_action", "monitor")),
                )
            )
        return observations

    def test_phase_flow_accepts_both_action_shapes(self) -> None:
        """Both action shapes are accepted, and a prior log_search yields an optimal result."""
        # Scenario 1: act without any tool calls.
        env = AdaptShieldEnvironment("direct-triage")
        phase1_obs = env.reset()
        self.assertEqual(phase1_obs.phase, 1)
        self.assertEqual(phase1_obs.turn, 1)
        # Float metadata: compare approximately, never with exact equality.
        self.assertAlmostEqual(phase1_obs.metadata["normalized_score"], 0.50)
        self.assertIn("mission_profile", phase1_obs.metadata)
        self.assertEqual(phase1_obs.metadata["world_split"], "train")
        self.assertIn(phase1_obs.metadata["world_family"], {"train-a", "train-b"})

        phase2_obs = env.step(
            AdaptShieldAction(
                threat_type="brute_force",
                confidence=0.9,
                target_node="auth_service",
                recommended_action="rate_limit",
            )
        )
        self.assertEqual(phase2_obs.phase, 2)
        self.assertEqual(phase2_obs.phase1_assessment["recommended_action"], "rate_limit")

        next_turn_obs = env.step(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(next_turn_obs.phase, 1)
        self.assertGreaterEqual(next_turn_obs.reward, 0.65)
        # With no tool calls made, the result text asks for stronger evidence.
        self.assertIn("requires stronger SOC evidence", next_turn_obs.last_action_result)
        breakdown = next_turn_obs.metadata["score_breakdown"]
        for component in ("business_impact", "dependency_blast_radius", "mission_alignment"):
            self.assertIn(component, breakdown)
        self.assertIn("active_defenses", next_turn_obs.metadata)
        self.assertIn("available_tools", next_turn_obs.metadata)
        tool_names = {tool["name"] for tool in next_turn_obs.metadata["available_tools"]}
        # Subset check; assertLessEqual reports the sets on failure.
        self.assertLessEqual(
            {"identity_lookup", "change_calendar_lookup", "netflow_lookup"},
            tool_names,
        )

        # Scenario 2: the same turn preceded by a log_search on the target.
        env = AdaptShieldEnvironment("direct-triage")
        env.reset()
        env.call_tool("log_search", node="auth_service")
        verified_obs = self._submit_turn(
            env,
            threat_type="brute_force",
            target="auth_service",
            action="rate_limit",
        )
        self.assertGreaterEqual(verified_obs.reward, 0.9)
        self.assertIn("Optimal: rate_limit", verified_obs.last_action_result)

    def test_client_payload_omits_empty_metadata_and_serializes_enums(self) -> None:
        """The client serializes each action shape into only its populated fields."""
        client = AdaptshieldEnv(base_url="http://localhost:7860")
        phase1_payload = client._step_payload(
            AdaptShieldAction(
                threat_type="benign",
                confidence=0.8,
                target_node="auth_service",
                recommended_action="monitor",
            )
        )
        self.assertEqual(
            phase1_payload,
            {
                "threat_type": "benign",
                "confidence": 0.8,
                "target_node": "auth_service",
                "recommended_action": "monitor",
            },
        )
        phase2_payload = client._step_payload(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(
            phase2_payload,
            {"action": "rate_limit", "target_node": "auth_service"},
        )

    def test_hard_task_records_verified_tool_evidence(self) -> None:
        """On the hard task, calling the required tools is recorded as tool evidence."""
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        # Resets can yield benign turns; retry until an attack turn appears.
        for _ in range(8):
            obs = env.reset()
            turn_config = self._current_turn_config(env)
            if not turn_config.get("is_benign", False):
                break
        else:
            self.fail("Expected a non-benign hard-task reset within 8 attempts")

        self.assertIn("available_tools", obs.metadata)
        self.assertNotIn("foothold_established", obs.metadata)
        target = str(turn_config.get("correct_target", "auth_service"))
        required_tools = _required_tool_fusion(
            "polymorphic-zero-day", str(turn_config.get("strategy", "benign"))
        )
        for tool_name in sorted(required_tools):
            tool_result = env.call_tool(tool_name, node=target)
            for hidden_field in self._HIDDEN_TOOL_FIELDS:
                self.assertNotIn(hidden_field, tool_result)
            self.assertTrue(tool_result.get("result_summary"))

        obs = self._submit_turn(
            env,
            threat_type=str(turn_config.get("strategy", "brute_force")),
            target=target,
            action=str(turn_config.get("correct_action", "monitor")),
        )
        breakdown = obs.metadata["score_breakdown"]
        self.assertTrue(breakdown["tool_verification_required"])
        self.assertTrue(breakdown["tool_evidence_found"])
        self.assertGreaterEqual(obs.reward, 0.65)

    def test_enterprise_context_tools_return_public_fields_only(self) -> None:
        """Enterprise context tools return their public fields and none of the hidden ones."""
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        env.reset()
        cases = (
            ("identity_lookup", "payment_service", ("account", "recent_source_host")),
            ("change_calendar_lookup", "api_gateway", ("scheduled", "change_status")),
            ("netflow_lookup", "database", ("traffic_pattern", "east_west_connections")),
        )
        for tool_name, node, public_fields in cases:
            with self.subTest(tool=tool_name):
                result = env.call_tool(tool_name, node=node)
                for field_name in public_fields:
                    self.assertIn(field_name, result)
                for hidden_field in self._HIDDEN_TOOL_FIELDS:
                    self.assertNotIn(hidden_field, result)

    def test_dual_pivot_requires_tool_confirmation_after_pivot(self) -> None:
        """After the pivot, the lateral-movement turn needs tool evidence to score as optimal."""
        # Scenario 1: no tools on the pivot target.
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()
        # The episode must still be running through the three pre-pivot turns.
        for turn_obs in self._play_scripted_turns(env, 3):
            self.assertFalse(turn_obs.done)
        turn_config = self._current_turn_config(env)
        self.assertEqual(turn_config.get("strategy"), "lateral_movement")
        target = str(turn_config.get("correct_target", "payment_service"))
        obs = self._submit_turn(
            env,
            threat_type="lateral_movement",
            target=target,
            action=str(turn_config.get("correct_action", "isolate")),
        )
        breakdown = obs.metadata["score_breakdown"]
        self.assertTrue(breakdown["tool_verification_required"])
        self.assertFalse(breakdown["tool_evidence_found"])
        self.assertIn("requires stronger SOC evidence", obs.last_action_result)

        # Scenario 2: query edr_status, log_search and identity_lookup first.
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()
        self._play_scripted_turns(env, 3)
        turn_config = self._current_turn_config(env)
        target = str(turn_config.get("correct_target", "payment_service"))
        for tool_name in ("edr_status", "log_search", "identity_lookup"):
            env.call_tool(tool_name, node=target)
        obs = self._submit_turn(
            env,
            threat_type="lateral_movement",
            target=target,
            action=str(turn_config.get("correct_action", "isolate")),
        )
        breakdown = obs.metadata["score_breakdown"]
        self.assertTrue(breakdown["tool_verification_required"])
        self.assertTrue(breakdown["tool_evidence_found"])
        self.assertIn("Optimal: isolate", obs.last_action_result)

    def test_world_family_metadata_and_surfaces_are_selectable(self) -> None:
        """An explicit world split/family is reflected in metadata and alert wording."""
        env = AdaptShieldEnvironment(
            "direct-triage",
            world_split="eval",
            world_family="eval-x",
        )
        obs = env.reset()
        self.assertEqual(obs.metadata["world_split"], "eval")
        self.assertEqual(obs.metadata["world_family"], "eval-x")
        alerts_blob = " ".join(obs.active_alerts).lower()
        self.assertTrue(
            any(
                phrase in alerts_blob
                for phrase in ("auth rejection burst", "credential reuse sweep")
            ),
            msg=alerts_blob,
        )

    def test_operational_modes_change_medium_and_hard_optimal_actions(self) -> None:
        """Operational modes rewrite the optimal action on the medium and hard tasks."""
        medium_env = AdaptShieldEnvironment(
            "dual-pivot",
            operational_mode="evidence_preservation",
            world_family="train-b",
        )
        medium_env.reset()
        self._play_scripted_turns(medium_env, 3)
        pivot_config = self._current_turn_config(medium_env)
        self.assertEqual(pivot_config.get("strategy"), "lateral_movement")
        self.assertEqual(pivot_config.get("correct_action"), "honeypot")

        hard_env = AdaptShieldEnvironment(
            "polymorphic-zero-day",
            operational_mode="forensic_hold",
            world_family="eval-y",
        )
        hard_obs = hard_env.reset()
        adjusted = hard_env._apply_operational_mode({
            "strategy": "exfiltration",
            "attack_stage": "exploit",
            "is_benign": False,
            "correct_action": "isolate",
            "correct_target": "database",
            "network_nodes": {"payment_service": {"status": "healthy", "request_rate": 85}},
            "active_alerts": [],
        })
        self.assertEqual(hard_obs.metadata["operational_mode"], "forensic_hold")
        self.assertEqual(adjusted.get("correct_action"), "honeypot")

    def test_prompt_bank_builds_phase_rows_without_gpu_deps(self) -> None:
        """The prompt bank builds phase-1/phase-2 rows with ``tokenizer=None``."""
        rows = train_module.build_prompt_bank(
            tokenizer=None,
            selected_task="all",
            curriculum=True,
            rollout_episodes=3,
            max_steps=6,
            use_tools=True,
            seed=42,
        )
        self.assertTrue(rows)
        phases = {int(row["phase"]) for row in rows}
        tasks = {str(row["task"]) for row in rows}
        self.assertIn(1, phases)
        self.assertIn(2, phases)
        self.assertTrue(tasks.intersection({"direct-triage", "dual-pivot", "polymorphic-zero-day"}))
        hard_rows = [row for row in rows if row["task"] == "polymorphic-zero-day"]
        self.assertTrue(hard_rows)
        self.assertTrue(any(int(row["tool_calls"]) >= 2 for row in hard_rows))

    def test_normalized_score_uses_reachable_reward_ceiling(self) -> None:
        """A run of 0.99 rewards normalizes to (approximately) 0.99."""
        rewards = [0.99] * 10
        # The score comes from float arithmetic over the rewards; exact equality
        # is fragile against rounding, so compare to 7 decimal places.
        self.assertAlmostEqual(normalize_episode_score(rewards), 0.99)
def _run_tests() -> None:
    """Run this module's test cases through unittest's command-line runner."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()