""" AdaptShield Grader Fully deterministic reward computation. No NLP, no LLM-as-judge, no external API calls. Runs in milliseconds. INVARIANT: All return values are strictly in [0.01, 0.99]. normalized_score is ALWAYS present in metadata — never missing. """ from typing import Any, Dict, List, Tuple # ── Reward constants ─────────────────────────────────────────────────────── BASE_REWARD = 0.50 # survival — agent gets this for staying alive P1_TYPE_BONUS = 0.15 # Phase 1 correctly identified threat type P1_TARGET_BONUS = 0.10 # Phase 1 correctly identified target node P2_OPTIMAL = 0.39 # correct + efficient action (0.50+0.39 = 0.89 base max) P2_HEAVY = 0.18 # correct action but heavy-handed P2_UNVERIFIED = 0.10 # right hard-task action without investigative evidence P2_WRONG = -0.25 # wrong action on real threat FALSE_POSITIVE = -0.39 # acted aggressively on benign event STAGE_ESCALATION = -0.10 # penalty applied when attack stage escalates CATASTROPHIC = -0.49 # database exfiltration completed → done=True # Optimal vs heavy-handed action mapping OPTIMAL_ACTION = { "brute_force": "rate_limit", "lateral_movement": "isolate", "exfiltration": "honeypot", "supply_chain": "patch", } HEAVY_ACTION = { "brute_force": "isolate", # works but unnecessary downtime "lateral_movement": "honeypot", # works but slower "exfiltration": "isolate", # works but database goes offline "supply_chain": "isolate", # works but gateway goes offline } # Catastrophic: exfiltration succeeds if agent uses wrong action 3+ times EXFIL_CATASTROPHIC_ACTIONS = ["monitor", "rate_limit", "patch"] # Operational model: criticality and dependency blast radius. # This is intentionally lightweight so it strengthens realism without making # rewards noisy or hard to train against. ASSET_CRITICALITY = { "auth_service": 0.70, "payment_service": 0.90, "database": 1.00, "api_gateway": 0.80, "none": 0.00, "unknown": 0.50, } SERVICE_DEPENDENCIES = { "auth_service": ["payment_service"], "payment_service": ["api_gateway"], "database": ["payment_service", "api_gateway"], "api_gateway": ["auth_service", "payment_service", "database"], "none": [], "unknown": [], } ACTION_DISRUPTION = { "monitor": 0.00, "patch": 0.06, "rate_limit": 0.10, "honeypot": 0.12, "isolate": 0.35, } MAX_OPERATIONAL_PENALTY = 0.05 MAX_MISSION_ADJUSTMENT = 0.04 BASE_REQUIRED_TOOL_FUSION = { "brute_force": {"log_search", "cmdb_lookup"}, "lateral_movement": {"edr_status", "log_search"}, "exfiltration": {"log_search", "edr_status"}, "supply_chain": {"vuln_lookup", "log_search"}, } TASK_REQUIRED_TOOL_FUSION = { "direct-triage": { "brute_force": {"log_search"}, }, "dual-pivot": { "lateral_movement": {"edr_status", "log_search", "identity_lookup"}, }, "polymorphic-zero-day": { "brute_force": {"log_search", "cmdb_lookup", "identity_lookup"}, "lateral_movement": {"edr_status", "log_search", "identity_lookup", "cmdb_lookup"}, "exfiltration": {"log_search", "edr_status", "netflow_lookup", "cmdb_lookup"}, "supply_chain": {"vuln_lookup", "log_search", "change_calendar_lookup", "cmdb_lookup"}, }, } def grade_step( phase1_action: Dict[str, Any], phase2_action: Dict[str, Any], turn_config: Dict[str, Any], stage: str, consecutive_wrong: int, task_name: str = "", foothold_established: bool = False, mission_profile: Dict[str, Any] | None = None, tool_context: Dict[str, Any] | None = None, ) -> Tuple[float, bool, Dict[str, Any]]: """ Grade a complete two-phase step. Args: phase1_action: Agent's Phase 1 output (threat assessment) phase2_action: Agent's Phase 2 output (defensive action) turn_config: Ground truth from AttackerEngine.build_observation() stage: Current attack stage (recon/exploit/exfiltration) consecutive_wrong: How many consecutive wrong actions agent has taken Returns: (reward, catastrophic_failure, info_dict) reward is ALWAYS in range [0.01, 0.99] """ is_benign = turn_config.get("is_benign", False) strategy = turn_config.get("strategy", "benign") correct_action = turn_config.get("correct_action", "monitor") correct_target = turn_config.get("correct_target", "none") mission_profile = mission_profile or {} tool_context = tool_context or {} contextual_countermeasure = ( task_name == "polymorphic-zero-day" and foothold_established and strategy == "lateral_movement" ) if contextual_countermeasure: correct_action = "honeypot" p1_type = str(phase1_action.get("threat_type", "")) p1_target = str(phase1_action.get("target_node", "")) p2_action = str(phase2_action.get("action", "monitor")) p2_target = str(phase2_action.get("target_node", "")) reward = BASE_REWARD catastrophic = False info = { "strategy": strategy, "stage": stage, "correct_action": correct_action, "correct_target": correct_target, "p1_threat_type": p1_type, "p1_target": p1_target, "p2_action": p2_action, "p2_target": p2_target, "is_benign": is_benign, "p1_type_correct": False, "p1_target_correct": False, "score_reason": "", "business_impact": 0.0, "availability_impact": 0.0, "security_risk": 0.0, "dependency_blast_radius": [], "operational_penalty": 0.0, "foothold_established": foothold_established, "contextual_countermeasure": contextual_countermeasure, "mission_profile": mission_profile, "mission_alignment": "neutral", "mission_adjustment": 0.0, "tool_verification_required": False, "tool_evidence_found": False, "tool_fusion_required": [], "tool_fusion_found": [], "tool_count": int(tool_context.get("tool_count", 0) or 0), } # ── False positive path ──────────────────────────────────────────────── if is_benign: if p2_action in ("rate_limit", "isolate", "honeypot", "patch"): reward += FALSE_POSITIVE info["acted_correctly"] = False info["score_reason"] = ( f"False positive: used {p2_action} on benign event" ) else: reward += 0.15 info["acted_correctly"] = True info["score_reason"] = "Correct: monitored benign event without overreacting" _apply_operational_impact( info=info, action=p2_action, target=p2_target, result_kind="false_positive" if p2_action != "monitor" else "optimal", strategy=strategy, stage=stage, ) _apply_mission_alignment( info=info, action=p2_action, target=p2_target, result_kind="false_positive" if p2_action != "monitor" else "optimal", mission_profile=mission_profile, ) reward += info["mission_adjustment"] reward -= info["operational_penalty"] return _clamp(reward), False, info # ── Phase 1 accuracy bonuses ─────────────────────────────────────────── if p1_type == strategy: reward += P1_TYPE_BONUS info["p1_type_correct"] = True if p1_target == correct_target: reward += P1_TARGET_BONUS info["p1_target_correct"] = True # ── Catastrophic failure: exfiltration completed ─────────────────────── if (strategy == "exfiltration" and stage == "exfiltration" and p2_action in EXFIL_CATASTROPHIC_ACTIONS): reward = BASE_REWARD + CATASTROPHIC catastrophic = True info["score_reason"] = ( "CATASTROPHIC: Database exfiltration completed — " f"agent used {p2_action} instead of honeypot" ) _apply_operational_impact( info=info, action=p2_action, target=p2_target, result_kind="wrong", strategy=strategy, stage=stage, ) _apply_mission_alignment( info=info, action=p2_action, target=p2_target, result_kind="wrong", mission_profile=mission_profile, ) reward += info["mission_adjustment"] return _clamp(reward), catastrophic, info # ── Stage escalation penalty ─────────────────────────────────────────── if stage == "exploit" and consecutive_wrong >= 1: reward += STAGE_ESCALATION elif stage == "exfiltration" and consecutive_wrong >= 2: reward += STAGE_ESCALATION * 2 # ── Phase 2 action grading ───────────────────────────────────────────── optimal = correct_action heavy = "" if contextual_countermeasure else HEAVY_ACTION.get(strategy, "") if heavy == optimal: heavy = "" requires_tool_verification = ( not is_benign and strategy in OPTIMAL_ACTION and ( task_name == "polymorphic-zero-day" or (task_name == "dual-pivot" and strategy == "lateral_movement") or (task_name == "direct-triage" and strategy == "brute_force") ) ) required_tools = _required_tool_fusion(task_name=task_name, strategy=strategy) tool_evidence_found, fusion_found = _has_relevant_tool_evidence( tool_context=tool_context, strategy=strategy, target=correct_target, required_tools=required_tools, ) info["tool_verification_required"] = requires_tool_verification info["tool_evidence_found"] = tool_evidence_found info["tool_fusion_required"] = sorted(required_tools) info["tool_fusion_found"] = sorted(fusion_found) if ( p2_action == optimal and p2_target == correct_target and requires_tool_verification and not tool_evidence_found ): reward += P2_UNVERIFIED result_kind = "unverified" info["score_reason"] = ( f"Unverified correct action: {p2_action} on {p2_target} would help, " f"but {task_name or 'this task'} requires stronger SOC evidence before full credit" ) elif p2_action == optimal and p2_target == correct_target: reward += P2_OPTIMAL result_kind = "optimal" if contextual_countermeasure: info["score_reason"] = ( f"Context-aware optimal: {p2_action} on {p2_target} — " "foothold already established, so deception beats isolation" ) else: info["score_reason"] = ( f"Optimal: {p2_action} on {p2_target} — attack stopped efficiently" ) elif p2_action == optimal and p2_target != correct_target: reward += P2_HEAVY * 0.5 result_kind = "wrong_target" info["score_reason"] = ( f"Right action ({p2_action}) but wrong target " f"(got {p2_target}, needed {correct_target})" ) elif p2_action == heavy and p2_target == correct_target: reward += P2_HEAVY result_kind = "heavy" info["score_reason"] = ( f"Heavy-handed: {p2_action} stopped attack on {p2_target} " f"but caused unnecessary service disruption" ) else: reward += P2_WRONG result_kind = "wrong" info["score_reason"] = ( f"Wrong: {p2_action} on {p2_target} — " f"needed {correct_action} on {correct_target}" ) acted_correctly = p2_action in (optimal, heavy) and p2_target == correct_target info["acted_correctly"] = acted_correctly _apply_operational_impact( info=info, action=p2_action, target=p2_target, result_kind=result_kind, strategy=strategy, stage=stage, ) _apply_mission_alignment( info=info, action=p2_action, target=p2_target, result_kind=result_kind, mission_profile=mission_profile, ) reward += info["mission_adjustment"] reward -= info["operational_penalty"] return _clamp(reward), catastrophic, info def _apply_mission_alignment( info: Dict[str, Any], action: str, target: str, result_kind: str, mission_profile: Dict[str, Any], ) -> None: sla_priority = str(mission_profile.get("sla_priority", "balanced")) primary_asset = str(mission_profile.get("primary_asset", "unknown")) risk_tolerance = str(mission_profile.get("risk_tolerance", "medium")) adjustment = 0.0 alignment = "neutral" if sla_priority == "availability" and action == "isolate" and target == primary_asset: adjustment -= MAX_MISSION_ADJUSTMENT alignment = "sla_violation" elif sla_priority == "availability" and result_kind == "optimal" and action in ("rate_limit", "patch", "monitor"): adjustment += MAX_MISSION_ADJUSTMENT / 2 alignment = "sla_aligned" elif sla_priority == "containment" and result_kind == "optimal" and action in ("honeypot", "isolate", "patch"): adjustment += MAX_MISSION_ADJUSTMENT / 2 alignment = "containment_aligned" elif risk_tolerance == "low" and result_kind in ("wrong", "wrong_target"): adjustment -= MAX_MISSION_ADJUSTMENT / 2 alignment = "risk_misaligned" info["mission_alignment"] = alignment info["mission_adjustment"] = round(adjustment, 2) def _apply_operational_impact( info: Dict[str, Any], action: str, target: str, result_kind: str, strategy: str, stage: str, ) -> None: """ Add deterministic business-impact telemetry and a small bounded penalty. The penalty is intentionally capped at 0.05 so existing learning curves keep their shape while demos can explain service criticality and blast radius. """ criticality = ASSET_CRITICALITY.get(target, ASSET_CRITICALITY["unknown"]) disruption = ACTION_DISRUPTION.get(action, 0.10) dependents = SERVICE_DEPENDENCIES.get(target, []) dependency_factor = min(1.0, 0.15 * len(dependents)) availability = round(min(1.0, disruption * (criticality + dependency_factor)), 2) security = _security_risk(result_kind=result_kind, strategy=strategy, stage=stage) impact = round(min(1.0, availability + security), 2) if result_kind == "optimal": penalty = 0.0 elif result_kind == "unverified": penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY / 2), 2) else: penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY), 2) info["business_impact"] = impact info["availability_impact"] = availability info["security_risk"] = security info["dependency_blast_radius"] = dependents if disruption > 0 else [] info["operational_penalty"] = penalty def _security_risk(result_kind: str, strategy: str, stage: str) -> float: if result_kind in ("optimal", "heavy"): return 0.0 if result_kind == "unverified": return 0.08 if result_kind == "false_positive": return 0.0 stage_risk = { "recon": 0.18, "exploit": 0.32, "exfiltration": 0.50, }.get(stage, 0.20) if strategy == "exfiltration": stage_risk += 0.15 elif strategy == "lateral_movement": stage_risk += 0.08 return round(min(1.0, stage_risk), 2) def _has_relevant_tool_evidence( tool_context: Dict[str, Any], strategy: str, target: str, required_tools: set[str], ) -> Tuple[bool, set[str]]: fusion_found = { str(result.get("tool", "")) for result in tool_context.get("tool_results", []) or [] if str(result.get("node", "")) == target } has_attack_evidence = False for evidence in tool_context.get("evidence", []) or []: if ( str(evidence.get("evidence_type", "")) == strategy and str(evidence.get("node", "")) == target and bool(evidence.get("verified", False)) ): has_attack_evidence = True break return has_attack_evidence and required_tools.issubset(fusion_found), fusion_found def _required_tool_fusion(task_name: str, strategy: str) -> set[str]: task_rules = TASK_REQUIRED_TOOL_FUSION.get(task_name, {}) if strategy in task_rules: return set(task_rules[strategy]) return set(BASE_REQUIRED_TOOL_FUSION.get(strategy, set())) def _clamp(value: float) -> float: """Strict bounds: never exactly 0.0 or 1.0.""" return max(0.01, min(0.99, round(value, 2))) def normalize_episode_score(rewards: List[float]) -> float: """ Normalize episode rewards to a single score strictly in (0.01, 0.99). ALWAYS returns a value — never raises, never returns exactly 0 or 1. """ if not rewards: return 0.50 total = sum(rewards) n = len(rewards) # Per-step rewards are clamped before they enter the episode reward list, # so normalization must use the reachable ceiling instead of the raw # unclamped sum of bonuses. Otherwise perfect episodes top out around 0.87. max_step_reward = _clamp( BASE_REWARD + P2_OPTIMAL + P1_TYPE_BONUS + P1_TARGET_BONUS + MAX_MISSION_ADJUSTMENT ) min_step_reward = _clamp(BASE_REWARD + CATASTROPHIC) max_poss = n * max_step_reward min_poss = n * min_step_reward if max_poss == min_poss: return 0.50 raw = (total - min_poss) / (max_poss - min_poss) return _clamp(raw)