"""Reward components for OpenRange Red and Blue operators. All rewards are grounded in verifiable container state -- flags via docker exec, patches via re-running golden path steps, detection via logged actions. Each individual reward class exposes a simple ``score()`` method with explicit arguments (no hidden state). The composite classes aggregate components and also expose a ``compute()`` method that accepts the environment's full context tuple ``(action, observation, state, snapshot, ctx)`` so the Environment can call rewards uniformly. Tier multipliers: {1: 1.0, 2: 1.5, 3: 2.0, 4: 2.5, 5: 3.0} """ from __future__ import annotations import logging import re from typing import Any from open_range.protocols import SnapshotSpec from open_range.models import RangeAction, RangeObservation, RangeState logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Tier multipliers # --------------------------------------------------------------------------- TIER_MULTIPLIERS: dict[int, float] = { 1: 1.0, 2: 1.5, 3: 2.0, 4: 2.5, 5: 3.0, } def _tier_multiplier(tier: int) -> float: """Return the complexity multiplier for the given tier.""" return TIER_MULTIPLIERS.get(tier, 1.0) _MATCH_TOKEN_RE = re.compile(r"[a-z0-9_./:-]{3,}") _MATCH_STOPWORDS = { "the", "and", "for", "with", "from", "that", "this", "then", "host", "user", "users", "detected", "detection", "attempt", "activity", "attack", "alert", "alerts", "found", "event", "events", "possible", "likely", "against", "into", "onto", "over", "under", "was", "were", "has", "have", } _ATTACK_SIGNAL_KEYWORDS = ( "sqli", "sql injection", "xss", "path traversal", "command injection", "ssrf", "bruteforce", "brute force", "scan", "nmap", "nikto", "sqlmap", "hydra", "ldap injection", "smb enumeration", "phish", "phishing", ) def _tokenize_for_matching(text: str) -> set[str]: tokens = {m.group(0).lower() for m in _MATCH_TOKEN_RE.finditer(text.lower())} return {token for token in tokens if token not in _MATCH_STOPWORDS} def _is_grounded_finding(finding: dict[str, Any], red_actions: list[dict[str, Any]]) -> bool: """Best-effort grounding check for Blue findings. Explicit grounding metadata wins. For legacy records without grounding, infer by overlap with recent Red command evidence. """ grounded = finding.get("grounded") if isinstance(grounded, bool): return grounded if not red_actions: return False content = str(finding.get("content", "") or "").lower() if not content: return False cmd_names = { str(record.get("cmd_name", "")).lower() for record in red_actions if record.get("cmd_name") } if any(cmd and cmd in content for cmd in cmd_names): return True finding_tokens = _tokenize_for_matching(content) if not finding_tokens: return False red_tokens: set[str] = set() for record in red_actions: red_tokens.update(_tokenize_for_matching(str(record.get("command", "")))) red_tokens.update(_tokenize_for_matching(str(record.get("target", "")))) cmd_name = str(record.get("cmd_name", "")).lower() if cmd_name: red_tokens.add(cmd_name) overlap = finding_tokens & red_tokens if len(overlap) >= 2: return True has_attack_signal = any(keyword in content for keyword in _ATTACK_SIGNAL_KEYWORDS) return has_attack_signal and len(overlap) >= 1 # =================================================================== # RED OPERATOR REWARDS # =================================================================== class FlagReward: """Binary reward: 1.0 if submitted flag matches a manifest flag, else 0.0.""" def score(self, submitted: str, valid_flags: set[str]) -> float: return 1.0 if submitted in valid_flags else 0.0 class HallucinationPenalty: """Penalty of -0.3 for each flag submission not in the manifest.""" PENALTY = -0.3 def score(self, hallucinated_count: int) -> float: return self.PENALTY * hallucinated_count class EfficiencyReward: """Exponential step-count discount: gamma^step_count.""" def __init__(self, gamma: float = 0.95) -> None: self.gamma = gamma def score(self, step_count: int) -> float: return self.gamma ** step_count class StealthReward: """Reward inversely proportional to Blue detections. ``1.0 - (detected_actions / total_actions)`` 0 detections -> 1.0, all actions detected -> 0.0. """ def score(self, detected_actions: int, total_actions: int) -> float: if total_actions == 0: return 1.0 return 1.0 - (detected_actions / total_actions) class EvidenceReward: """Quality score for submitted evidence (0.0-1.0). Scores evidence based on structural quality: - Non-empty content (+0.25) - References a known vulnerability type (+0.25) - References a host from the snapshot topology (+0.25) - Sufficient detail (>50 chars) (+0.25) """ VULN_KEYWORDS = { "sqli", "sql injection", "xss", "cross-site", "idor", "path traversal", "directory traversal", "command injection", "ssrf", "weak credentials", "broken auth", "jwt", } def score(self, content: str, host_names: set[str] | None = None) -> float: if not content: return 0.0 lower = content.lower() score = 0.25 # non-empty if any(kw in lower for kw in self.VULN_KEYWORDS): score += 0.25 if host_names and any(h in lower for h in host_names if h): score += 0.25 if len(content) > 50: score += 0.25 return min(score, 1.0) class SocialEngineeringReward: """Binary: 1.0 if an NPC fell for a social engineering attack this episode.""" def score(self, npc_compromised: bool) -> float: return 1.0 if npc_compromised else 0.0 class ComplexityBonus: """Tier-based multiplier for rewards. Tier 1=1.0x, Tier 2=1.5x, Tier 3=2.0x, Tier 4=2.5x, Tier 5=3.0x. """ def multiplier(self, tier: int) -> float: return _tier_multiplier(tier) # --------------------------------------------------------------------------- # Composite Red # --------------------------------------------------------------------------- class CompositeRedReward: """Weighted sum of all Red reward components. Default weights:: flag: 0.40 (primary objective) efficiency: 0.10 (speed bonus) stealth: 0.15 (avoid Blue detection) evidence: 0.10 (documentation quality) social: 0.05 (social engineering -- Level 1+ only) Complexity is applied as a multiplier on the weighted sum. Hallucination penalty is additive (not scaled by tier). """ def __init__( self, gamma: float = 0.95, weights: dict[str, float] | None = None, ) -> None: self.weights = weights or { "flag": 0.40, "efficiency": 0.10, "stealth": 0.15, "evidence": 0.10, "social": 0.05, } self.flag = FlagReward() self.efficiency = EfficiencyReward(gamma) self.stealth = StealthReward() self.evidence = EvidenceReward() self.social = SocialEngineeringReward() self.complexity = ComplexityBonus() self.hallucination = HallucinationPenalty() # -- low-level score (explicit args, no env coupling) --------------- def score( self, submitted_flag: str | None = None, step_count: int = 0, detected_actions: int = 0, total_actions: int = 0, evidence_content: str = "", evidence_hosts: set[str] | None = None, npc_compromised: bool = False, hallucinated_count: int = 0, tier: int = 1, valid_flags: set[str] | None = None, snapshot: SnapshotSpec | None = None, ) -> float: if valid_flags is None and snapshot is not None: valid_flags = {f.value for f in snapshot.flags} if valid_flags is None: valid_flags = set() total = 0.0 if submitted_flag is not None: total += self.weights["flag"] * self.flag.score(submitted_flag, valid_flags) total += self.weights["efficiency"] * self.efficiency.score(step_count) total += self.weights["stealth"] * self.stealth.score(detected_actions, total_actions) total += self.weights["evidence"] * self.evidence.score(evidence_content, evidence_hosts) total += self.weights["social"] * self.social.score(npc_compromised) scaled = total * self.complexity.multiplier(tier) scaled += self.hallucination.score(hallucinated_count) return scaled # -- high-level compute (called by RangeEnvironment.step) ----------- def compute( self, action: RangeAction, observation: RangeObservation, state: RangeState, snapshot: SnapshotSpec, ctx: dict[str, Any] | None = None, ) -> float: """Compute composite Red reward from full environment context.""" ctx = ctx or {} red_history = ctx.get("red_history", []) blue_history = ctx.get("blue_history", []) npc_log = ctx.get("npc_traffic_log", []) # Flag reward -- observation.flags_captured is set by environment valid_flags = {f.value for f in snapshot.flags} flag_score = 0.0 if observation.flags_captured: for fc in observation.flags_captured: flag_score += self.flag.score(fc, valid_flags) # Efficiency eff_score = self.efficiency.score(state.step_count) # Stealth -- coupled to Blue red_actions = [ r for r in red_history if r.get("type") not in ("hallucinated_flag", "evidence") ] blue_findings = [b for b in blue_history if b.get("type") == "finding"] detections = min(len(blue_findings), len(red_actions)) stealth_score = self.stealth.score(detections, len(red_actions)) # Evidence evidence_entries = [r for r in red_history if r.get("type") == "evidence"] evidence_content = evidence_entries[-1].get("content", "") if evidence_entries else "" topo_hosts: set[str] = set() if isinstance(snapshot.topology, dict): topo_hosts = { h.get("name", "") if isinstance(h, dict) else "" for h in snapshot.topology.get("hosts", []) } evidence_score = self.evidence.score(evidence_content, topo_hosts) # Social engineering -- reactive NPC actions from send_phish or # agent loop stimulus reactions (type=social_engineering, result=success/blocked) social_successes = [ e for e in npc_log if e.get("type") == "social_engineering" and e.get("result") == "success" ] social_score = self.social.score(bool(social_successes)) # Hallucination penalty applies only when an event is logged this step. hallucinated = [ r for r in red_history if r.get("type") == "hallucinated_flag" and r.get("step") == state.step_count ] # Aggregate components = { "flag": flag_score, "efficiency": eff_score, "stealth": stealth_score, "evidence": evidence_score, "social": social_score, } base = sum(self.weights.get(k, 0.0) * v for k, v in components.items()) scaled = base * self.complexity.multiplier(state.tier) halluc = self.hallucination.score(len(hallucinated)) total = scaled + halluc logger.debug( "Red reward: components=%s base=%.3f scaled=%.3f halluc=%.3f total=%.3f", components, base, scaled, halluc, total, ) return total # =================================================================== # BLUE OPERATOR REWARDS # =================================================================== class DetectionReward: """True-positive rate: correctly_identified / total_red_actions.""" def score(self, true_positives: int, total_red_actions: int) -> float: if total_red_actions == 0: return 0.0 return min(true_positives / total_red_actions, 1.0) class PatchReward: """Binary: 1.0 if the patch blocks the exploit, else 0.0. In a full deployment, validation re-runs the golden path exploit step after patching. Here we track whether a patch command was issued. """ def score(self, exploit_blocked: bool) -> float: return 1.0 if exploit_blocked else 0.0 class AvailabilityReward: """Fraction of services still healthy after Blue's defensive actions.""" def score(self, healthy: int, total: int) -> float: if total == 0: return 1.0 return healthy / total class PhishingDetectionReward: """Reward for correctly identifying social engineering events. Level 1+ only. Returns fraction of SE events Blue flagged. """ def score(self, phishing_detected: int, total_se_events: int) -> float: if total_se_events == 0: return 0.0 return min(phishing_detected / total_se_events, 1.0) class FalsePositiveReward: """Penalty of -0.2 per NPC event incorrectly flagged as attack.""" PENALTY = -0.2 def score(self, false_positives: int) -> float: return self.PENALTY * false_positives # --------------------------------------------------------------------------- # Composite Blue # --------------------------------------------------------------------------- class CompositeBlueReward: """Weighted sum of all Blue reward components. Default weights:: detection: 0.35 patch: 0.25 availability: 0.15 phishing: 0.05 Complexity is applied as a multiplier on the weighted sum. False positive penalty is additive (not scaled by tier). """ def __init__( self, weights: dict[str, float] | None = None, ) -> None: self.weights = weights or { "detection": 0.35, "patch": 0.25, "availability": 0.15, "phishing": 0.05, } self.detection = DetectionReward() self.patch = PatchReward() self.availability = AvailabilityReward() self.phishing = PhishingDetectionReward() self.complexity = ComplexityBonus() self.false_positive = FalsePositiveReward() # -- low-level score ------------------------------------------------ def score( self, true_positives: int = 0, total_red_actions: int = 0, exploit_blocked: bool = False, healthy_services: int = 0, total_services: int = 0, phishing_detected: int = 0, total_se_events: int = 0, false_positives: int = 0, tier: int = 1, ) -> float: total = 0.0 total += self.weights["detection"] * self.detection.score(true_positives, total_red_actions) total += self.weights["patch"] * self.patch.score(exploit_blocked) total += self.weights["availability"] * self.availability.score(healthy_services, total_services) total += self.weights["phishing"] * self.phishing.score(phishing_detected, total_se_events) scaled = total * self.complexity.multiplier(tier) scaled += self.false_positive.score(false_positives) return scaled # -- high-level compute (called by RangeEnvironment.step) ----------- def compute( self, action: RangeAction, observation: RangeObservation, state: RangeState, snapshot: SnapshotSpec, ctx: dict[str, Any] | None = None, ) -> float: """Compute composite Blue reward from full environment context.""" ctx = ctx or {} red_history = ctx.get("red_history", []) blue_history = ctx.get("blue_history", []) npc_log = ctx.get("npc_traffic_log", []) # Detection -- coupled to Red red_actions = [ r for r in red_history if r.get("type") not in ("hallucinated_flag", "evidence") ] blue_findings = [b for b in blue_history if b.get("type") == "finding"] grounded_findings = [b for b in blue_findings if _is_grounded_finding(b, red_actions)] tp = min(len(grounded_findings), len(red_actions)) detect_score = self.detection.score(tp, len(red_actions)) # Patch -- require grounded validation signal, not command name heuristics. patch_validated = bool(ctx.get("patch_validated", False)) if not patch_validated: patch_events = [ b for b in blue_history if b.get("type") == "patch_validation" and b.get("result") in (True, "success", "blocked") ] patch_validated = bool(patch_events) patch_score = self.patch.score(patch_validated) # Availability svc = state.services_status or {} total_svc = len(svc) healthy = sum( 1 for v in svc.values() if v in ("running", "healthy", True, "up") ) avail_score = self.availability.score(healthy, total_svc) # Phishing detection se_events = [e for e in npc_log if e.get("type") == "social_engineering"] phishing_kw = ("phish", "social", "spear", "pretext", "suspicious email") phishing_found = [ b for b in blue_findings if any(kw in b.get("content", "").lower() for kw in phishing_kw) ] phish_score = self.phishing.score(len(phishing_found), len(se_events)) # False positives -- findings that reference benign NPC traffic npc_sources = { e.get("source", "") for e in npc_log if e.get("label") == "benign" } fp_count = 0 for f in blue_findings: content = f.get("content", "").lower() if any(src.lower() in content for src in npc_sources if src): fp_count += 1 # Aggregate components = { "detection": detect_score, "patch": patch_score, "availability": avail_score, "phishing": phish_score, } base = sum(self.weights.get(k, 0.0) * v for k, v in components.items()) scaled = base * self.complexity.multiplier(state.tier) fp_penalty = self.false_positive.score(fp_count) total = scaled + fp_penalty logger.debug( "Blue reward: components=%s base=%.3f scaled=%.3f fp=%.3f total=%.3f", components, base, scaled, fp_penalty, total, ) return total