#!/usr/bin/env python3 """Rule-based AdaptShield baseline with evaluator-style stdout.""" from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict, List REPO_ROOT = Path(__file__).resolve().parent if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from models import AdaptShieldAction from server.adaptshield_environment import AdaptShieldEnvironment TASKS = ["direct-triage", "dual-pivot", "polymorphic-zero-day"] BENCHMARK = "adaptshield" MODEL_NAME = "rule-baseline" MAX_STEPS = 30 POLICY = { "brute_force": ("auth_service", "rate_limit"), "lateral_movement": ("payment_service", "isolate"), "exfiltration": ("database", "honeypot"), "supply_chain": ("api_gateway", "patch"), "benign": ("api_gateway", "monitor"), } def log_start(task: str) -> None: print(f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}", flush=True) def log_step(step: int, action: Dict[str, Any], reward: float, done: bool) -> None: action_str = json.dumps(action, separators=(",", ":")) if len(action_str) > 100: action_str = action_str[:97] + "..." print( f"[STEP] step={step} action={action_str} " f"reward={reward:.2f} done={str(done).lower()} error=null", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: reward_str = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={str(success).lower()} steps={steps} " f"score={score:.3f} rewards={reward_str}", flush=True, ) def print_replay(task: str, metadata: Dict[str, Any]) -> None: replay = metadata.get("episode_replay") or [] print() print(f"Replay: {task}") if not replay: print(" No replay metadata available.") return for row in replay: shift = " STRATEGY SHIFT:" if row.get("shift") else "" foothold = " FOOTHOLD:" if row.get("foothold_transition") else "" blast_radius = row.get("blast_radius") or [] blast = ",".join(blast_radius) if blast_radius else "none" active_defenses = row.get("active_defenses") or [] print( f" Turn {row.get('turn')}:" f"{shift}{foothold} {row.get('p1')} -> {row.get('p2_action')}" f"({row.get('target', 'unknown')}) -> {row.get('result')}" f" | impact={float(row.get('impact', 0.0)):.2f}" f" blast_radius={blast}" f" active_controls={len(active_defenses)}" f" mission={row.get('mission_alignment', 'neutral')}" ) def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str: auth = network_nodes.get("auth_service", {}) payment = network_nodes.get("payment_service", {}) database = network_nodes.get("database", {}) gateway = network_nodes.get("api_gateway", {}) if float(auth.get("error_rate", 0.0)) >= 0.10: return "brute_force" if payment.get("status") == "suspicious" or float(payment.get("cpu", 0)) >= 55: return "lateral_movement" if float(database.get("outbound_mb", 0)) >= 50: return "exfiltration" if gateway.get("status") == "suspicious": return "supply_chain" return "benign" def phase1_payload(obs) -> Dict[str, Any]: threat_type = classify_from_metrics(obs.network_nodes) target_node, action = POLICY[threat_type] return { "threat_type": threat_type, "confidence": 0.90, "target_node": target_node, "recommended_action": action, "reasoning": "rule-based metric classifier", } def phase2_payload(obs) -> Dict[str, Any]: assessment = obs.phase1_assessment or {} threat_type = str(assessment.get("threat_type", "benign")) fallback_target, fallback_action = POLICY.get(threat_type, POLICY["benign"]) action = str(assessment.get("recommended_action") or fallback_action) target_node = str(assessment.get("target_node") or fallback_target) return { "action": action, "target_node": target_node, "reasoning": "execute analyst recommendation", } def action_from_payload(payload: Dict[str, Any]) -> AdaptShieldAction: return AdaptShieldAction(**payload) def run_task(task: str, emit_logs: bool = True) -> Dict[str, Any]: env = AdaptShieldEnvironment(task_name=task) obs = env.reset() rewards: List[float] = [] steps = 0 if emit_logs: log_start(task) while not obs.done and steps < MAX_STEPS: if obs.phase == 1: payload = phase1_payload(obs) else: payload = phase2_payload(obs) obs = env.step(action_from_payload(payload)) reward = float(obs.reward) rewards.append(reward) steps += 1 if emit_logs: log_step(steps, payload, reward, obs.done) metadata = obs.metadata if isinstance(obs.metadata, dict) else {} score = float(metadata.get("normalized_score", 0.01)) success = obs.done and 0.01 <= score <= 0.99 if emit_logs: log_end(success, steps, score, rewards) return { "task": task, "score": score, "steps": steps, "done": bool(obs.done), "rewards": rewards, "metadata": metadata, "normalized_score_present": "normalized_score" in metadata, "success": success, } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run AdaptShield rule baseline.") parser.add_argument( "--task", default="direct-triage", choices=TASKS + ["all"], help="Task to run, or 'all' for every task.", ) parser.add_argument( "--replay", action="store_true", help="Print a human-readable final episode replay.", ) return parser.parse_args() def main() -> int: args = parse_args() tasks = TASKS if args.task == "all" else [args.task] for index, task in enumerate(tasks): if index: print() result = run_task(task, emit_logs=True) if args.replay: print_replay(task, result["metadata"]) return 0 if __name__ == "__main__": raise SystemExit(main())