# adaptshield / server/grader.py
# SaiManish123 — Initial deploy of AdaptShield two-phase cybersecurity environment
# (commit c1060df, verified)
"""
AdaptShield Grader
Fully deterministic reward computation.
No NLP, no LLM-as-judge, no external API calls.
Runs in milliseconds.
INVARIANT: All return values are strictly in [0.01, 0.99].
normalized_score is ALWAYS present in metadata — never missing.
"""
from typing import Any, Dict, List, Tuple
# ── Reward constants ───────────────────────────────────────────────────────
# All deltas below are added to BASE_REWARD inside grade_step(); the final
# per-step reward is clamped to [0.01, 0.99] by _clamp().
BASE_REWARD = 0.50 # survival — agent gets this for staying alive
P1_TYPE_BONUS = 0.15 # Phase 1 correctly identified threat type
P1_TARGET_BONUS = 0.10 # Phase 1 correctly identified target node
P2_OPTIMAL = 0.39 # correct + efficient action (0.50+0.39 = 0.89 before P1 bonuses / clamp)
P2_HEAVY = 0.18 # correct action but heavy-handed (half of this for right action, wrong target)
P2_UNVERIFIED = 0.10 # right hard-task action without investigative evidence
P2_WRONG = -0.25 # wrong action on real threat
FALSE_POSITIVE = -0.39 # acted aggressively on benign event
STAGE_ESCALATION = -0.10 # applied once in "exploit" after >=1 wrong action, doubled in "exfiltration" after >=2
CATASTROPHIC = -0.49 # exfiltration completed; grade_step returns catastrophic=True (presumably ends the episode — confirm in env)
# Optimal vs heavy-handed action mapping.
# NOTE: grade_step takes the actual optimal action from turn_config["correct_action"];
# OPTIMAL_ACTION is only consulted for membership (which strategies are real,
# tool-verifiable attacks).
OPTIMAL_ACTION = {
    "brute_force": "rate_limit",
    "lateral_movement": "isolate",
    "exfiltration": "honeypot",
    "supply_chain": "patch",
}
# Actions that still stop the attack but earn only P2_HEAVY (partial credit).
HEAVY_ACTION = {
    "brute_force": "isolate", # works but unnecessary downtime
    "lateral_movement": "honeypot", # works but slower
    "exfiltration": "isolate", # works but database goes offline
    "supply_chain": "isolate", # works but gateway goes offline
}
# Catastrophic: in grade_step, any of these actions during the "exfiltration"
# stage of an exfiltration attack completes the exfiltration immediately.
# NOTE(review): an older note said "3+ wrong actions"; that count is not
# enforced here — presumably stage progression elsewhere encodes it. Confirm.
EXFIL_CATASTROPHIC_ACTIONS = ["monitor", "rate_limit", "patch"]
# Operational model: criticality and dependency blast radius.
# This is intentionally lightweight so it strengthens realism without making
# rewards noisy or hard to train against.
# Per-asset criticality weight used by _apply_operational_impact();
# targets not listed fall back to the "unknown" entry.
ASSET_CRITICALITY = {
    "auth_service": 0.70,
    "payment_service": 0.90,
    "database": 1.00,
    "api_gateway": 0.80,
    "none": 0.00,
    "unknown": 0.50,
}
# Downstream services affected when the key asset is disrupted. Each entry
# adds 0.15 to the blast-radius factor (capped at 1.0); the list is also
# reported as info["dependency_blast_radius"] for disruptive actions.
SERVICE_DEPENDENCIES = {
    "auth_service": ["payment_service"],
    "payment_service": ["api_gateway"],
    "database": ["payment_service", "api_gateway"],
    "api_gateway": ["auth_service", "payment_service", "database"],
    "none": [],
    "unknown": [],
}
# How disruptive each defensive action is to availability; unlisted actions
# are treated as 0.10 by _apply_operational_impact().
ACTION_DISRUPTION = {
    "monitor": 0.00,
    "patch": 0.06,
    "rate_limit": 0.10,
    "honeypot": 0.12,
    "isolate": 0.35,
}
# Upper bound on the per-step operational penalty.
MAX_OPERATIONAL_PENALTY = 0.05
# Mission-profile reward delta: the full value is subtracted for an SLA
# violation; half of it is added/subtracted for the other alignment outcomes.
MAX_MISSION_ADJUSTMENT = 0.04
# Tools that must have been run against the correct target (plus a verified
# evidence record) before a tool-verified task grants full optimal credit.
BASE_REQUIRED_TOOL_FUSION = {
    "brute_force": {"log_search", "cmdb_lookup"},
    "lateral_movement": {"edr_status", "log_search"},
    "exfiltration": {"log_search", "edr_status"},
    "supply_chain": {"vuln_lookup", "log_search"},
}
# Per-task overrides of BASE_REQUIRED_TOOL_FUSION, keyed task -> strategy.
TASK_REQUIRED_TOOL_FUSION = {
    "direct-triage": {
        "brute_force": {"log_search"},
    },
    "dual-pivot": {
        "lateral_movement": {"edr_status", "log_search", "identity_lookup"},
    },
    "polymorphic-zero-day": {
        "brute_force": {"log_search", "cmdb_lookup", "identity_lookup"},
        "lateral_movement": {"edr_status", "log_search", "identity_lookup", "cmdb_lookup"},
        "exfiltration": {"log_search", "edr_status", "netflow_lookup", "cmdb_lookup"},
        "supply_chain": {"vuln_lookup", "log_search", "change_calendar_lookup", "cmdb_lookup"},
    },
}
def grade_step(
    phase1_action: Dict[str, Any],
    phase2_action: Dict[str, Any],
    turn_config: Dict[str, Any],
    stage: str,
    consecutive_wrong: int,
    task_name: str = "",
    foothold_established: bool = False,
    mission_profile: Dict[str, Any] | None = None,
    tool_context: Dict[str, Any] | None = None,
) -> Tuple[float, bool, Dict[str, Any]]:
    """
    Grade a complete two-phase step.

    Args:
        phase1_action: Agent's Phase 1 output (threat assessment); reads
            ``threat_type`` and ``target_node``.
        phase2_action: Agent's Phase 2 output (defensive action); reads
            ``action`` (default ``"monitor"``) and ``target_node``.
        turn_config: Ground truth from AttackerEngine.build_observation();
            reads ``is_benign``, ``strategy``, ``correct_action`` and
            ``correct_target``.
        stage: Current attack stage (recon/exploit/exfiltration)
        consecutive_wrong: How many consecutive wrong actions agent has taken
        task_name: Task identifier; selects tool-verification rules and the
            polymorphic-zero-day contextual countermeasure.
        foothold_established: Whether the attacker already has a foothold.
            In "polymorphic-zero-day" lateral movement this switches the
            optimal response to ``honeypot``.
        mission_profile: Optional SLA / risk profile used for a small bounded
            reward adjustment.
        tool_context: Optional investigation results (``tool_results``,
            ``evidence``, ``tool_count``) used to decide whether a correct
            action was backed by evidence.

    Returns:
        (reward, catastrophic_failure, info_dict)
        reward is ALWAYS in range [0.01, 0.99]
    """
    is_benign = turn_config.get("is_benign", False)
    strategy = turn_config.get("strategy", "benign")
    correct_action = turn_config.get("correct_action", "monitor")
    correct_target = turn_config.get("correct_target", "none")
    mission_profile = mission_profile or {}
    tool_context = tool_context or {}

    # With a foothold already established in the hardest task, deception is
    # the expected response to lateral movement rather than isolation.
    contextual_countermeasure = (
        task_name == "polymorphic-zero-day" and
        foothold_established and
        strategy == "lateral_movement"
    )
    if contextual_countermeasure:
        correct_action = "honeypot"

    p1_type = str(phase1_action.get("threat_type", ""))
    p1_target = str(phase1_action.get("target_node", ""))
    p2_action = str(phase2_action.get("action", "monitor"))
    p2_target = str(phase2_action.get("target_node", ""))

    reward = BASE_REWARD
    catastrophic = False
    info = {
        "strategy": strategy,
        "stage": stage,
        "correct_action": correct_action,
        "correct_target": correct_target,
        "p1_threat_type": p1_type,
        "p1_target": p1_target,
        "p2_action": p2_action,
        "p2_target": p2_target,
        "is_benign": is_benign,
        "p1_type_correct": False,
        "p1_target_correct": False,
        "score_reason": "",
        "business_impact": 0.0,
        "availability_impact": 0.0,
        "security_risk": 0.0,
        "dependency_blast_radius": [],
        "operational_penalty": 0.0,
        "foothold_established": foothold_established,
        "contextual_countermeasure": contextual_countermeasure,
        "mission_profile": mission_profile,
        "mission_alignment": "neutral",
        "mission_adjustment": 0.0,
        "tool_verification_required": False,
        "tool_evidence_found": False,
        "tool_fusion_required": [],
        "tool_fusion_found": [],
        "tool_count": int(tool_context.get("tool_count", 0) or 0),
    }

    # ── False positive path ────────────────────────────────────────────────
    if is_benign:
        aggressive = p2_action in ("rate_limit", "isolate", "honeypot", "patch")
        if aggressive:
            reward += FALSE_POSITIVE
            info["acted_correctly"] = False
            info["score_reason"] = (
                f"False positive: used {p2_action} on benign event"
            )
        else:
            reward += 0.15
            info["acted_correctly"] = True
            info["score_reason"] = "Correct: monitored benign event without overreacting"
        # Telemetry follows the same decision as the reward above, so an
        # action credited as correct is never also penalised as a false positive.
        reward = _settle_reward(
            reward,
            info,
            action=p2_action,
            target=p2_target,
            result_kind="false_positive" if aggressive else "optimal",
            strategy=strategy,
            stage=stage,
            mission_profile=mission_profile,
        )
        return _clamp(reward), False, info

    # ── Phase 1 accuracy bonuses ───────────────────────────────────────────
    if p1_type == strategy:
        reward += P1_TYPE_BONUS
        info["p1_type_correct"] = True
    if p1_target == correct_target:
        reward += P1_TARGET_BONUS
        info["p1_target_correct"] = True

    # ── Catastrophic failure: exfiltration completed ───────────────────────
    if (strategy == "exfiltration" and
            stage == "exfiltration" and
            p2_action in EXFIL_CATASTROPHIC_ACTIONS):
        # Reward is reset, discarding any Phase 1 bonuses.
        reward = BASE_REWARD + CATASTROPHIC
        catastrophic = True
        info["acted_correctly"] = False
        info["score_reason"] = (
            "CATASTROPHIC: Database exfiltration completed — "
            f"agent used {p2_action} instead of honeypot"
        )
        reward = _settle_reward(
            reward,
            info,
            action=p2_action,
            target=p2_target,
            result_kind="wrong",
            strategy=strategy,
            stage=stage,
            mission_profile=mission_profile,
        )
        return _clamp(reward), catastrophic, info

    # ── Stage escalation penalty ───────────────────────────────────────────
    if stage == "exploit" and consecutive_wrong >= 1:
        reward += STAGE_ESCALATION
    elif stage == "exfiltration" and consecutive_wrong >= 2:
        reward += STAGE_ESCALATION * 2

    # ── Phase 2 action grading ─────────────────────────────────────────────
    optimal = correct_action
    heavy = "" if contextual_countermeasure else HEAVY_ACTION.get(strategy, "")
    if heavy == optimal:
        heavy = ""

    requires_tool_verification = (
        not is_benign and
        strategy in OPTIMAL_ACTION and
        (
            task_name == "polymorphic-zero-day" or
            (task_name == "dual-pivot" and strategy == "lateral_movement") or
            (task_name == "direct-triage" and strategy == "brute_force")
        )
    )
    required_tools = _required_tool_fusion(task_name=task_name, strategy=strategy)
    tool_evidence_found, fusion_found = _has_relevant_tool_evidence(
        tool_context=tool_context,
        strategy=strategy,
        target=correct_target,
        required_tools=required_tools,
    )
    info["tool_verification_required"] = requires_tool_verification
    info["tool_evidence_found"] = tool_evidence_found
    info["tool_fusion_required"] = sorted(required_tools)
    info["tool_fusion_found"] = sorted(fusion_found)

    target_ok = p2_target == correct_target
    if (
        p2_action == optimal and
        target_ok and
        requires_tool_verification and
        not tool_evidence_found
    ):
        reward += P2_UNVERIFIED
        result_kind = "unverified"
        info["score_reason"] = (
            f"Unverified correct action: {p2_action} on {p2_target} would help, "
            f"but {task_name or 'this task'} requires stronger SOC evidence before full credit"
        )
    elif p2_action == optimal and target_ok:
        reward += P2_OPTIMAL
        result_kind = "optimal"
        if contextual_countermeasure:
            info["score_reason"] = (
                f"Context-aware optimal: {p2_action} on {p2_target} — "
                "foothold already established, so deception beats isolation"
            )
        else:
            info["score_reason"] = (
                f"Optimal: {p2_action} on {p2_target} — attack stopped efficiently"
            )
    elif p2_action == optimal and not target_ok:
        reward += P2_HEAVY * 0.5
        result_kind = "wrong_target"
        info["score_reason"] = (
            f"Right action ({p2_action}) but wrong target "
            f"(got {p2_target}, needed {correct_target})"
        )
    elif heavy and p2_action == heavy and target_ok:
        # `heavy and` guard: when no heavy alternative exists (heavy == ""),
        # an empty action string must not earn heavy-handed credit.
        reward += P2_HEAVY
        result_kind = "heavy"
        info["score_reason"] = (
            f"Heavy-handed: {p2_action} stopped attack on {p2_target} "
            f"but caused unnecessary service disruption"
        )
    else:
        reward += P2_WRONG
        result_kind = "wrong"
        info["score_reason"] = (
            f"Wrong: {p2_action} on {p2_target} — "
            f"needed {correct_action} on {correct_target}"
        )

    # Right action on the right node (optionally heavy-handed or unverified).
    info["acted_correctly"] = result_kind in ("optimal", "heavy", "unverified")

    reward = _settle_reward(
        reward,
        info,
        action=p2_action,
        target=p2_target,
        result_kind=result_kind,
        strategy=strategy,
        stage=stage,
        mission_profile=mission_profile,
    )
    return _clamp(reward), catastrophic, info


def _settle_reward(
    reward: float,
    info: Dict[str, Any],
    *,
    action: str,
    target: str,
    result_kind: str,
    strategy: str,
    stage: str,
    mission_profile: Dict[str, Any],
) -> float:
    """
    Record operational/mission telemetry in ``info`` and apply it to ``reward``.

    Returns the unclamped reward after adding ``info["mission_adjustment"]``
    and subtracting ``info["operational_penalty"]`` (in that order).
    """
    _apply_operational_impact(
        info=info,
        action=action,
        target=target,
        result_kind=result_kind,
        strategy=strategy,
        stage=stage,
    )
    _apply_mission_alignment(
        info=info,
        action=action,
        target=target,
        result_kind=result_kind,
        mission_profile=mission_profile,
    )
    reward += info["mission_adjustment"]
    reward -= info["operational_penalty"]
    return reward
def _apply_mission_alignment(
    info: Dict[str, Any],
    action: str,
    target: str,
    result_kind: str,
    mission_profile: Dict[str, Any],
) -> None:
    """
    Score how well the chosen action fits the mission profile.

    Writes a label into ``info["mission_alignment"]`` and a reward delta
    into ``info["mission_adjustment"]``. The first matching rule wins:
    isolating the primary asset under an availability SLA costs the full
    MAX_MISSION_ADJUSTMENT; the other outcomes move the reward by half of it.
    """
    sla = str(mission_profile.get("sla_priority", "balanced"))
    asset = str(mission_profile.get("primary_asset", "unknown"))
    tolerance = str(mission_profile.get("risk_tolerance", "medium"))
    half_step = MAX_MISSION_ADJUSTMENT / 2
    was_optimal = result_kind == "optimal"

    if sla == "availability" and action == "isolate" and target == asset:
        label, delta = "sla_violation", -MAX_MISSION_ADJUSTMENT
    elif sla == "availability" and was_optimal and action in ("rate_limit", "patch", "monitor"):
        label, delta = "sla_aligned", half_step
    elif sla == "containment" and was_optimal and action in ("honeypot", "isolate", "patch"):
        label, delta = "containment_aligned", half_step
    elif tolerance == "low" and result_kind in ("wrong", "wrong_target"):
        label, delta = "risk_misaligned", -half_step
    else:
        label, delta = "neutral", 0.0

    info["mission_alignment"] = label
    info["mission_adjustment"] = round(delta, 2)
def _apply_operational_impact(
    info: Dict[str, Any],
    action: str,
    target: str,
    result_kind: str,
    strategy: str,
    stage: str,
) -> None:
    """
    Add deterministic business-impact telemetry and a small bounded penalty.

    The penalty is capped at MAX_OPERATIONAL_PENALTY so existing learning
    curves keep their shape while demos can explain service criticality and
    blast radius. Optimal actions pay no penalty; unverified-but-correct
    actions pay half rate.

    Writes ``business_impact``, ``availability_impact``, ``security_risk``,
    ``dependency_blast_radius`` and ``operational_penalty`` into ``info``.
    """
    criticality = ASSET_CRITICALITY.get(target, ASSET_CRITICALITY["unknown"])
    # Unrecognised actions are treated as moderately disruptive.
    disruption = ACTION_DISRUPTION.get(action, 0.10)
    dependents = SERVICE_DEPENDENCIES.get(target, [])
    # Each downstream dependent widens the blast radius by 0.15 (capped at 1.0).
    dependency_factor = min(1.0, 0.15 * len(dependents))
    availability = round(min(1.0, disruption * (criticality + dependency_factor)), 2)
    security = _security_risk(result_kind=result_kind, strategy=strategy, stage=stage)
    impact = round(min(1.0, availability + security), 2)
    if result_kind == "optimal":
        penalty = 0.0
    elif result_kind == "unverified":
        penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY / 2), 2)
    else:
        penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY), 2)
    info["business_impact"] = impact
    info["availability_impact"] = availability
    info["security_risk"] = security
    # Copy: storing the table's own list would let any caller that mutates
    # info["dependency_blast_radius"] corrupt SERVICE_DEPENDENCIES globally.
    info["dependency_blast_radius"] = list(dependents) if disruption > 0 else []
    info["operational_penalty"] = penalty
def _security_risk(result_kind: str, strategy: str, stage: str) -> float:
if result_kind in ("optimal", "heavy"):
return 0.0
if result_kind == "unverified":
return 0.08
if result_kind == "false_positive":
return 0.0
stage_risk = {
"recon": 0.18,
"exploit": 0.32,
"exfiltration": 0.50,
}.get(stage, 0.20)
if strategy == "exfiltration":
stage_risk += 0.15
elif strategy == "lateral_movement":
stage_risk += 0.08
return round(min(1.0, stage_risk), 2)
def _has_relevant_tool_evidence(
tool_context: Dict[str, Any],
strategy: str,
target: str,
required_tools: set[str],
) -> Tuple[bool, set[str]]:
fusion_found = {
str(result.get("tool", ""))
for result in tool_context.get("tool_results", []) or []
if str(result.get("node", "")) == target
}
has_attack_evidence = False
for evidence in tool_context.get("evidence", []) or []:
if (
str(evidence.get("evidence_type", "")) == strategy and
str(evidence.get("node", "")) == target and
bool(evidence.get("verified", False))
):
has_attack_evidence = True
break
return has_attack_evidence and required_tools.issubset(fusion_found), fusion_found
def _required_tool_fusion(task_name: str, strategy: str) -> set[str]:
    """
    Return a fresh set of tools that must be run before full credit is given.

    A task-specific rule for ``strategy`` takes precedence; otherwise the
    per-strategy baseline is used, or an empty set if there is none.
    """
    overrides = TASK_REQUIRED_TOOL_FUSION.get(task_name, {})
    try:
        tools = overrides[strategy]
    except KeyError:
        tools = BASE_REQUIRED_TOOL_FUSION.get(strategy, set())
    # Always copy so callers cannot mutate the shared tables.
    return set(tools)
def _clamp(value: float) -> float:
"""Strict bounds: never exactly 0.0 or 1.0."""
return max(0.01, min(0.99, round(value, 2)))
def normalize_episode_score(rewards: List[float]) -> float:
    """
    Normalize episode rewards to a single score in [0.01, 0.99].

    The score places the episode's total reward between the worst and best
    totals reachable by ``len(rewards)`` clamped steps. An empty (or None)
    episode scores the neutral 0.50. Any iterable of numbers is accepted; it
    is materialised once, so generators work too. Bounds are inclusive: the
    result may be exactly 0.01 or 0.99, but never 0 or 1. Non-numeric
    entries are not handled and will raise from ``sum``.
    """
    # Materialise first: a generator is always truthy and has no len().
    rewards = [] if rewards is None else list(rewards)
    if not rewards:
        return 0.50
    total = sum(rewards)
    n = len(rewards)
    # Per-step rewards are clamped before they enter the episode reward list,
    # so normalization must use the reachable ceiling instead of the raw
    # unclamped sum of bonuses. Otherwise perfect episodes top out around 0.87.
    max_step_reward = _clamp(
        BASE_REWARD + P2_OPTIMAL + P1_TYPE_BONUS + P1_TARGET_BONUS + MAX_MISSION_ADJUSTMENT
    )
    min_step_reward = _clamp(BASE_REWARD + CATASTROPHIC)
    max_poss = n * max_step_reward
    min_poss = n * min_step_reward
    if max_poss == min_poss:
        return 0.50
    raw = (total - min_poss) / (max_poss - min_poss)
    return _clamp(raw)