"""Regression tests for the AdaptShield package, environment, client and trainer."""

import sys
import unittest
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
PACKAGE_ROOT = REPO_ROOT / "adaptshield"

# The project modules below are imported by their bare names, so both the
# repository root and the package directory have to be on ``sys.path`` first.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
if str(PACKAGE_ROOT) not in sys.path:
    sys.path.insert(0, str(PACKAGE_ROOT))

import __init__ as adaptshield
import server.app as server_app
import train as train_module
from client import AdaptshieldEnv
from models import AdaptShieldAction
from server.adaptshield_environment import AdaptShieldEnvironment
from server.grader import normalize_episode_score, _required_tool_fusion


def _snapshot_turn_config(env: AdaptShieldEnvironment) -> dict:
    """Return a copy of the environment's private ``_turn_config``.

    An empty dict is returned when the attribute is missing or falsy, so the
    callers can always use ``.get`` with their own defaults.
    """
    return dict(getattr(env, "_turn_config", {}) or {})


def _play_turn(env: AdaptShieldEnvironment, threat_type, target_node, action):
    """Play one full turn and return the observation of its second step.

    The first ``step`` submits the assessment-shaped action (threat type,
    confidence 0.9, target and recommended action); the second submits the
    action-shaped one with the same target and action.
    """
    env.step(
        AdaptShieldAction(
            threat_type=threat_type,
            confidence=0.9,
            target_node=target_node,
            recommended_action=action,
        )
    )
    return env.step(AdaptShieldAction(action=action, target_node=target_node))


def _play_scripted_turns(env: AdaptShieldEnvironment, turns: int) -> list:
    """Play ``turns`` turns using the answers stored in each turn's config.

    The config is snapshotted once per turn, before the first step, and the
    same values are used for both steps of that turn.

    Returns:
        The second-step observation of every turn, in play order.
    """
    observations = []
    for _ in range(turns):
        config = _snapshot_turn_config(env)
        observations.append(
            _play_turn(
                env,
                threat_type=str(config.get("strategy", "brute_force")),
                target_node=str(config.get("correct_target", "auth_service")),
                action=str(config.get("correct_action", "monitor")),
            )
        )
    return observations


class PackageRegressionTests(unittest.TestCase):
    """Import-level checks for the package and the server application."""

    def test_package_import_exports_expected_symbols(self) -> None:
        """The package ``__all__`` lists the action, observation and client."""
        self.assertIn("AdaptShieldAction", adaptshield.__all__)
        self.assertIn("AdaptShieldObservation", adaptshield.__all__)
        self.assertIn("AdaptshieldEnv", adaptshield.__all__)

    def test_server_app_imports_fastapi_instance(self) -> None:
        """``server.app.app`` is an instance of a class named ``FastAPI``."""
        self.assertEqual(server_app.app.__class__.__name__, "FastAPI")


class EnvironmentRegressionTests(unittest.TestCase):
    """Behavioural checks for the environment, client payloads and trainer."""

    def _assert_no_verification_fields(self, tool_result) -> None:
        """Assert a tool result exposes neither ``verified`` nor ``evidence_type``."""
        self.assertNotIn("verified", tool_result)
        self.assertNotIn("evidence_type", tool_result)

    def test_phase_flow_accepts_both_action_shapes(self) -> None:
        """A turn alternates phase 1 -> phase 2 -> phase 1 across two steps.

        Also checks that the same correct action scores at least 0.65 without
        a prior tool call and at least 0.9 after a ``log_search`` call.
        """
        env = AdaptShieldEnvironment("direct-triage")
        phase1_obs = env.reset()
        self.assertEqual(phase1_obs.phase, 1)
        self.assertEqual(phase1_obs.turn, 1)
        self.assertEqual(phase1_obs.metadata["normalized_score"], 0.50)
        self.assertIn("mission_profile", phase1_obs.metadata)
        self.assertEqual(phase1_obs.metadata["world_split"], "train")
        self.assertIn(phase1_obs.metadata["world_family"], {"train-a", "train-b"})

        phase2_obs = env.step(
            AdaptShieldAction(
                threat_type="brute_force",
                confidence=0.9,
                target_node="auth_service",
                recommended_action="rate_limit",
            )
        )
        self.assertEqual(phase2_obs.phase, 2)
        self.assertEqual(
            phase2_obs.phase1_assessment["recommended_action"], "rate_limit"
        )

        next_turn_obs = env.step(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(next_turn_obs.phase, 1)
        self.assertGreaterEqual(next_turn_obs.reward, 0.65)
        self.assertIn(
            "requires stronger SOC evidence", next_turn_obs.last_action_result
        )
        score_breakdown = next_turn_obs.metadata["score_breakdown"]
        self.assertIn("business_impact", score_breakdown)
        self.assertIn("dependency_blast_radius", score_breakdown)
        self.assertIn("mission_alignment", score_breakdown)
        self.assertIn("active_defenses", next_turn_obs.metadata)
        self.assertIn("available_tools", next_turn_obs.metadata)
        tool_names = {
            tool["name"] for tool in next_turn_obs.metadata["available_tools"]
        }
        self.assertTrue(
            {
                "identity_lookup",
                "change_calendar_lookup",
                "netflow_lookup",
            }.issubset(tool_names)
        )

        # Same scenario on a fresh environment, this time with a tool call
        # made before the turn is played.
        env = AdaptShieldEnvironment("direct-triage")
        env.reset()
        env.call_tool("log_search", node="auth_service")
        verified_obs = _play_turn(env, "brute_force", "auth_service", "rate_limit")
        self.assertGreaterEqual(verified_obs.reward, 0.9)
        self.assertIn("Optimal: rate_limit", verified_obs.last_action_result)

    def test_client_payload_omits_empty_metadata_and_serializes_enums(self) -> None:
        """``_step_payload`` emits only the populated fields, as plain values."""
        client = AdaptshieldEnv(base_url="http://localhost:7860")

        phase1_payload = client._step_payload(
            AdaptShieldAction(
                threat_type="benign",
                confidence=0.8,
                target_node="auth_service",
                recommended_action="monitor",
            )
        )
        self.assertEqual(
            phase1_payload,
            {
                "threat_type": "benign",
                "confidence": 0.8,
                "target_node": "auth_service",
                "recommended_action": "monitor",
            },
        )

        phase2_payload = client._step_payload(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(
            phase2_payload,
            {"action": "rate_limit", "target_node": "auth_service"},
        )

    def test_hard_task_records_verified_tool_evidence(self) -> None:
        """Calling the required tools on the hard task is recorded as evidence."""
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        # Re-roll the episode until the first turn is not benign; the
        # ``else`` branch only runs when all eight resets were benign.
        for _ in range(8):
            obs = env.reset()
            turn_config = _snapshot_turn_config(env)
            if not turn_config.get("is_benign", False):
                break
        else:
            self.fail("Expected a non-benign hard-task reset within 8 attempts")

        self.assertIn("available_tools", obs.metadata)
        self.assertNotIn("foothold_established", obs.metadata)

        target = str(turn_config.get("correct_target", "auth_service"))
        required_tools = _required_tool_fusion(
            "polymorphic-zero-day", str(turn_config.get("strategy", "benign"))
        )
        for tool_name in sorted(required_tools):
            tool_result = env.call_tool(tool_name, node=target)
            self._assert_no_verification_fields(tool_result)
            self.assertTrue(tool_result.get("result_summary"))

        obs = _play_turn(
            env,
            threat_type=turn_config.get("strategy", "brute_force"),
            target_node=target,
            action=turn_config.get("correct_action", "monitor"),
        )
        breakdown = obs.metadata["score_breakdown"]
        self.assertTrue(breakdown["tool_verification_required"])
        self.assertTrue(breakdown["tool_evidence_found"])
        self.assertGreaterEqual(obs.reward, 0.65)

    def test_enterprise_context_tools_return_public_fields_only(self) -> None:
        """Context tools return their data fields without verification fields."""
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        env.reset()

        identity = env.call_tool("identity_lookup", node="payment_service")
        self.assertIn("account", identity)
        self.assertIn("recent_source_host", identity)
        self._assert_no_verification_fields(identity)

        change = env.call_tool("change_calendar_lookup", node="api_gateway")
        self.assertIn("scheduled", change)
        self.assertIn("change_status", change)
        self._assert_no_verification_fields(change)

        netflow = env.call_tool("netflow_lookup", node="database")
        self.assertIn("traffic_pattern", netflow)
        self.assertIn("east_west_connections", netflow)
        self._assert_no_verification_fields(netflow)

    def test_dual_pivot_requires_tool_confirmation_after_pivot(self) -> None:
        """After three turns the strategy is lateral movement and needs tools.

        Without tool calls the correct action is flagged as lacking evidence;
        with ``edr_status``, ``log_search`` and ``identity_lookup`` called on
        the target first, it is reported as optimal.
        """
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()
        for obs in _play_scripted_turns(env, 3):
            self.assertFalse(obs.done)

        turn_config = _snapshot_turn_config(env)
        self.assertEqual(turn_config.get("strategy"), "lateral_movement")
        target = str(turn_config.get("correct_target", "payment_service"))
        obs = _play_turn(
            env,
            threat_type="lateral_movement",
            target_node=target,
            action=str(turn_config.get("correct_action", "isolate")),
        )
        self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
        self.assertFalse(obs.metadata["score_breakdown"]["tool_evidence_found"])
        self.assertIn("requires stronger SOC evidence", obs.last_action_result)

        # Replay the same episode, gathering tool evidence before the
        # fourth turn is answered.
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()
        _play_scripted_turns(env, 3)

        turn_config = _snapshot_turn_config(env)
        target = str(turn_config.get("correct_target", "payment_service"))
        env.call_tool("edr_status", node=target)
        env.call_tool("log_search", node=target)
        env.call_tool("identity_lookup", node=target)
        obs = _play_turn(
            env,
            threat_type="lateral_movement",
            target_node=target,
            action=str(turn_config.get("correct_action", "isolate")),
        )
        self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
        self.assertTrue(obs.metadata["score_breakdown"]["tool_evidence_found"])
        self.assertIn("Optimal: isolate", obs.last_action_result)

    def test_world_family_metadata_and_surfaces_are_selectable(self) -> None:
        """An explicit world split and family are reflected in the observation."""
        env = AdaptShieldEnvironment(
            "direct-triage",
            world_split="eval",
            world_family="eval-x",
        )
        obs = env.reset()
        self.assertEqual(obs.metadata["world_split"], "eval")
        self.assertEqual(obs.metadata["world_family"], "eval-x")

        alerts_blob = " ".join(obs.active_alerts).lower()
        self.assertTrue(
            "auth rejection burst" in alerts_blob
            or "credential reuse sweep" in alerts_blob
        )

    def test_operational_modes_change_medium_and_hard_optimal_actions(self) -> None:
        """The operational mode changes which action counts as correct."""
        medium_env = AdaptShieldEnvironment(
            "dual-pivot",
            operational_mode="evidence_preservation",
            world_family="train-b",
        )
        medium_env.reset()
        _play_scripted_turns(medium_env, 3)

        pivot_config = _snapshot_turn_config(medium_env)
        self.assertEqual(pivot_config.get("strategy"), "lateral_movement")
        self.assertEqual(pivot_config.get("correct_action"), "honeypot")

        hard_env = AdaptShieldEnvironment(
            "polymorphic-zero-day",
            operational_mode="forensic_hold",
            world_family="eval-y",
        )
        hard_obs = hard_env.reset()
        adjusted = hard_env._apply_operational_mode(
            {
                "strategy": "exfiltration",
                "attack_stage": "exploit",
                "is_benign": False,
                "correct_action": "isolate",
                "correct_target": "database",
                "network_nodes": {
                    "payment_service": {"status": "healthy", "request_rate": 85}
                },
                "active_alerts": [],
            }
        )
        self.assertEqual(hard_obs.metadata["operational_mode"], "forensic_hold")
        self.assertEqual(adjusted.get("correct_action"), "honeypot")

    def test_prompt_bank_builds_phase_rows_without_gpu_deps(self) -> None:
        """``build_prompt_bank`` works with ``tokenizer=None`` and covers both phases."""
        rows = train_module.build_prompt_bank(
            tokenizer=None,
            selected_task="all",
            curriculum=True,
            rollout_episodes=3,
            max_steps=6,
            use_tools=True,
            seed=42,
        )
        self.assertTrue(rows)

        phases = {int(row["phase"]) for row in rows}
        tasks = {str(row["task"]) for row in rows}
        self.assertIn(1, phases)
        self.assertIn(2, phases)
        self.assertTrue(
            tasks.intersection({"direct-triage", "dual-pivot", "polymorphic-zero-day"})
        )

        hard_rows = [row for row in rows if row["task"] == "polymorphic-zero-day"]
        self.assertTrue(hard_rows)
        self.assertTrue(any(int(row["tool_calls"]) >= 2 for row in hard_rows))

    def test_normalized_score_uses_reachable_reward_ceiling(self) -> None:
        """Ten rewards of 0.99 normalize to an episode score of 0.99."""
        rewards = [0.99] * 10
        self.assertEqual(normalize_episode_score(rewards), 0.99)


if __name__ == "__main__":
    unittest.main()