"""Reward components for OpenRange Red and Blue operators.
All rewards are grounded in verifiable container state -- flags via docker exec,
patches via re-running golden path steps, detection via logged actions.
Each individual reward class exposes a simple ``score()`` method with explicit
arguments (no hidden state). The composite classes aggregate components and
also expose a ``compute()`` method that accepts the environment's full step
context ``(action, observation, state, snapshot, ctx)``, so the Environment can
call either composite reward uniformly.
Tier multipliers: {1: 1.0, 2: 1.5, 3: 2.0, 4: 2.5, 5: 3.0}
"""
from __future__ import annotations
import logging
import re
from typing import Any
from open_range.protocols import SnapshotSpec
from open_range.models import RangeAction, RangeObservation, RangeState
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Tier multipliers
# ---------------------------------------------------------------------------
TIER_MULTIPLIERS: dict[int, float] = {
1: 1.0,
2: 1.5,
3: 2.0,
4: 2.5,
5: 3.0,
}
def _tier_multiplier(tier: int) -> float:
"""Return the complexity multiplier for the given tier."""
return TIER_MULTIPLIERS.get(tier, 1.0)
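# Illustrative lookups: a Tier 3 scenario scales the weighted sum by 2.0, and an
# out-of-range tier falls back to 1.0.
#
#     >>> _tier_multiplier(3)
#     2.0
#     >>> _tier_multiplier(99)
#     1.0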
_MATCH_TOKEN_RE = re.compile(r"[a-z0-9_./:-]{3,}")
_MATCH_STOPWORDS = {
"the", "and", "for", "with", "from", "that", "this", "then", "host",
"user", "users", "detected", "detection", "attempt", "activity", "attack",
"alert", "alerts", "found", "event", "events", "possible", "likely",
"against", "into", "onto", "over", "under", "was", "were", "has", "have",
}
_ATTACK_SIGNAL_KEYWORDS = (
"sqli", "sql injection", "xss", "path traversal", "command injection",
"ssrf", "bruteforce", "brute force", "scan", "nmap", "nikto", "sqlmap",
"hydra", "ldap injection", "smb enumeration", "phish", "phishing",
)
def _tokenize_for_matching(text: str) -> set[str]:
tokens = {m.group(0).lower() for m in _MATCH_TOKEN_RE.finditer(text.lower())}
return {token for token in tokens if token not in _MATCH_STOPWORDS}
def _is_grounded_finding(finding: dict[str, Any], red_actions: list[dict[str, Any]]) -> bool:
"""Best-effort grounding check for Blue findings.
Explicit grounding metadata wins. For legacy records without grounding,
infer by overlap with recent Red command evidence.
"""
grounded = finding.get("grounded")
if isinstance(grounded, bool):
return grounded
if not red_actions:
return False
content = str(finding.get("content", "") or "").lower()
if not content:
return False
cmd_names = {
str(record.get("cmd_name", "")).lower()
for record in red_actions
if record.get("cmd_name")
}
if any(cmd and cmd in content for cmd in cmd_names):
return True
finding_tokens = _tokenize_for_matching(content)
if not finding_tokens:
return False
red_tokens: set[str] = set()
for record in red_actions:
red_tokens.update(_tokenize_for_matching(str(record.get("command", ""))))
red_tokens.update(_tokenize_for_matching(str(record.get("target", ""))))
cmd_name = str(record.get("cmd_name", "")).lower()
if cmd_name:
red_tokens.add(cmd_name)
overlap = finding_tokens & red_tokens
if len(overlap) >= 2:
return True
has_attack_signal = any(keyword in content for keyword in _ATTACK_SIGNAL_KEYWORDS)
return has_attack_signal and len(overlap) >= 1
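# Illustrative grounding checks (host and command strings are hypothetical): a
# finding that names a Red tool and target is treated as grounded, while a
# finding with no Red activity to match against is not.
#
#     >>> _is_grounded_finding(
#     ...     {"content": "Detected sqlmap activity against shop-web"},
#     ...     [{"cmd_name": "sqlmap", "command": "sqlmap -u http://shop-web/login",
#     ...       "target": "shop-web"}],
#     ... )
#     True
#     >>> _is_grounded_finding({"content": "Suspicious traffic"}, [])
#     False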
# ===================================================================
# RED OPERATOR REWARDS
# ===================================================================
class FlagReward:
"""Binary reward: 1.0 if submitted flag matches a manifest flag, else 0.0."""
def score(self, submitted: str, valid_flags: set[str]) -> float:
return 1.0 if submitted in valid_flags else 0.0
class HallucinationPenalty:
"""Penalty of -0.3 for each flag submission not in the manifest."""
PENALTY = -0.3
def score(self, hallucinated_count: int) -> float:
return self.PENALTY * hallucinated_count
class EfficiencyReward:
"""Exponential step-count discount: gamma^step_count."""
def __init__(self, gamma: float = 0.95) -> None:
self.gamma = gamma
def score(self, step_count: int) -> float:
return self.gamma ** step_count
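# Illustrative discount (hypothetical step count): with the default gamma=0.95,
# a 10-step episode keeps roughly 60% of the efficiency component.
#
#     >>> round(EfficiencyReward(gamma=0.95).score(step_count=10), 3)
#     0.599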
class StealthReward:
"""Reward inversely proportional to Blue detections.
``1.0 - (detected_actions / total_actions)``
0 detections -> 1.0, all actions detected -> 0.0.
"""
def score(self, detected_actions: int, total_actions: int) -> float:
if total_actions == 0:
return 1.0
return 1.0 - (detected_actions / total_actions)
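# Illustrative values: two of five Red actions detected leaves 1 - 2/5 = 0.6.
#
#     >>> StealthReward().score(detected_actions=2, total_actions=5)
#     0.6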
class EvidenceReward:
"""Quality score for submitted evidence (0.0-1.0).
Scores evidence based on structural quality:
- Non-empty content (+0.25)
- References a known vulnerability type (+0.25)
- References a host from the snapshot topology (+0.25)
- Sufficient detail (>50 chars) (+0.25)
"""
VULN_KEYWORDS = {
"sqli", "sql injection", "xss", "cross-site", "idor",
"path traversal", "directory traversal", "command injection",
"ssrf", "weak credentials", "broken auth", "jwt",
}
def score(self, content: str, host_names: set[str] | None = None) -> float:
if not content:
return 0.0
lower = content.lower()
score = 0.25 # non-empty
if any(kw in lower for kw in self.VULN_KEYWORDS):
score += 0.25
if host_names and any(h in lower for h in host_names if h):
score += 0.25
if len(content) > 50:
score += 0.25
return min(score, 1.0)
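# Illustrative scoring (host names are hypothetical): the note below is
# non-empty, names a known vulnerability type, references a topology host, and
# exceeds 50 characters, so all four structural checks pass.
#
#     >>> EvidenceReward().score(
#     ...     "SQL injection in the login form of shop-web, confirmed via sqlmap",
#     ...     host_names={"shop-web", "mail-gw"},
#     ... )
#     1.0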
class SocialEngineeringReward:
"""Binary: 1.0 if an NPC fell for a social engineering attack this episode."""
def score(self, npc_compromised: bool) -> float:
return 1.0 if npc_compromised else 0.0
class ComplexityBonus:
"""Tier-based multiplier for rewards.
Tier 1=1.0x, Tier 2=1.5x, Tier 3=2.0x, Tier 4=2.5x, Tier 5=3.0x.
"""
def multiplier(self, tier: int) -> float:
return _tier_multiplier(tier)
# ---------------------------------------------------------------------------
# Composite Red
# ---------------------------------------------------------------------------
class CompositeRedReward:
"""Weighted sum of all Red reward components.
Default weights::
flag: 0.40 (primary objective)
efficiency: 0.10 (speed bonus)
stealth: 0.15 (avoid Blue detection)
evidence: 0.10 (documentation quality)
social: 0.05 (social engineering -- Level 1+ only)
Complexity is applied as a multiplier on the weighted sum.
Hallucination penalty is additive (not scaled by tier).
"""
def __init__(
self,
gamma: float = 0.95,
weights: dict[str, float] | None = None,
) -> None:
self.weights = weights or {
"flag": 0.40,
"efficiency": 0.10,
"stealth": 0.15,
"evidence": 0.10,
"social": 0.05,
}
self.flag = FlagReward()
self.efficiency = EfficiencyReward(gamma)
self.stealth = StealthReward()
self.evidence = EvidenceReward()
self.social = SocialEngineeringReward()
self.complexity = ComplexityBonus()
self.hallucination = HallucinationPenalty()
# -- low-level score (explicit args, no env coupling) ---------------
def score(
self,
submitted_flag: str | None = None,
step_count: int = 0,
detected_actions: int = 0,
total_actions: int = 0,
evidence_content: str = "",
evidence_hosts: set[str] | None = None,
npc_compromised: bool = False,
hallucinated_count: int = 0,
tier: int = 1,
valid_flags: set[str] | None = None,
snapshot: SnapshotSpec | None = None,
) -> float:
if valid_flags is None and snapshot is not None:
valid_flags = {f.value for f in snapshot.flags}
if valid_flags is None:
valid_flags = set()
total = 0.0
if submitted_flag is not None:
total += self.weights["flag"] * self.flag.score(submitted_flag, valid_flags)
total += self.weights["efficiency"] * self.efficiency.score(step_count)
total += self.weights["stealth"] * self.stealth.score(detected_actions, total_actions)
total += self.weights["evidence"] * self.evidence.score(evidence_content, evidence_hosts)
total += self.weights["social"] * self.social.score(npc_compromised)
scaled = total * self.complexity.multiplier(tier)
scaled += self.hallucination.score(hallucinated_count)
return scaled
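    # Worked example (hypothetical values; "FLAG{demo}" is not a real manifest
    # flag): one valid flag at step 0, no detections across 4 actions, one
    # hallucinated submission, Tier 2. Weighted sum = 0.40 + 0.10 + 0.15 = 0.65,
    # scaled by 1.5 to 0.975, minus the 0.3 hallucination penalty:
    #
    #     >>> r = CompositeRedReward()
    #     >>> round(r.score(submitted_flag="FLAG{demo}", valid_flags={"FLAG{demo}"},
    #     ...               total_actions=4, hallucinated_count=1, tier=2), 3)
    #     0.675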
# -- high-level compute (called by RangeEnvironment.step) -----------
def compute(
self,
action: RangeAction,
observation: RangeObservation,
state: RangeState,
snapshot: SnapshotSpec,
ctx: dict[str, Any] | None = None,
) -> float:
"""Compute composite Red reward from full environment context."""
ctx = ctx or {}
red_history = ctx.get("red_history", [])
blue_history = ctx.get("blue_history", [])
npc_log = ctx.get("npc_traffic_log", [])
# Flag reward -- observation.flags_captured is set by environment
valid_flags = {f.value for f in snapshot.flags}
flag_score = 0.0
if observation.flags_captured:
for fc in observation.flags_captured:
flag_score += self.flag.score(fc, valid_flags)
# Efficiency
eff_score = self.efficiency.score(state.step_count)
# Stealth -- coupled to Blue
red_actions = [
r for r in red_history
if r.get("type") not in ("hallucinated_flag", "evidence")
]
blue_findings = [b for b in blue_history if b.get("type") == "finding"]
detections = min(len(blue_findings), len(red_actions))
stealth_score = self.stealth.score(detections, len(red_actions))
# Evidence
evidence_entries = [r for r in red_history if r.get("type") == "evidence"]
evidence_content = evidence_entries[-1].get("content", "") if evidence_entries else ""
topo_hosts: set[str] = set()
if isinstance(snapshot.topology, dict):
topo_hosts = {
h.get("name", "") if isinstance(h, dict) else ""
for h in snapshot.topology.get("hosts", [])
}
evidence_score = self.evidence.score(evidence_content, topo_hosts)
# Social engineering -- reactive NPC actions from send_phish or
# agent loop stimulus reactions (type=social_engineering, result=success/blocked)
social_successes = [
e for e in npc_log
if e.get("type") == "social_engineering" and e.get("result") == "success"
]
social_score = self.social.score(bool(social_successes))
# Hallucination penalty applies only when an event is logged this step.
hallucinated = [
r
for r in red_history
if r.get("type") == "hallucinated_flag"
and r.get("step") == state.step_count
]
# Aggregate
components = {
"flag": flag_score,
"efficiency": eff_score,
"stealth": stealth_score,
"evidence": evidence_score,
"social": social_score,
}
base = sum(self.weights.get(k, 0.0) * v for k, v in components.items())
scaled = base * self.complexity.multiplier(state.tier)
halluc = self.hallucination.score(len(hallucinated))
total = scaled + halluc
logger.debug(
"Red reward: components=%s base=%.3f scaled=%.3f halluc=%.3f total=%.3f",
components, base, scaled, halluc, total,
)
return total
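# Sketch of the ``ctx`` layout compute() reads. The record fields match the keys
# accessed above; the concrete values, hosts, and the "command" type label are
# hypothetical -- the environment is what actually builds these histories.
#
#     ctx = {
#         "red_history": [
#             {"type": "command", "cmd_name": "nmap", "command": "nmap -sV 10.0.0.5",
#              "target": "10.0.0.5", "step": 3},
#             {"type": "evidence", "content": "SQLi on the login form", "step": 7},
#             {"type": "hallucinated_flag", "step": 9},
#         ],
#         "blue_history": [
#             {"type": "finding", "content": "nmap scan against 10.0.0.5"},
#         ],
#         "npc_traffic_log": [
#             {"type": "social_engineering", "result": "blocked", "source": "npc-03"},
#         ],
#     }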
# ===================================================================
# BLUE OPERATOR REWARDS
# ===================================================================
class DetectionReward:
"""True-positive rate: correctly_identified / total_red_actions."""
def score(self, true_positives: int, total_red_actions: int) -> float:
if total_red_actions == 0:
return 0.0
return min(true_positives / total_red_actions, 1.0)
class PatchReward:
"""Binary: 1.0 if the patch blocks the exploit, else 0.0.
In a full deployment, validation re-runs the golden path exploit step
after patching. Here we track whether a patch command was issued.
"""
def score(self, exploit_blocked: bool) -> float:
return 1.0 if exploit_blocked else 0.0
class AvailabilityReward:
"""Fraction of services still healthy after Blue's defensive actions."""
def score(self, healthy: int, total: int) -> float:
if total == 0:
return 1.0
return healthy / total
class PhishingDetectionReward:
"""Reward for correctly identifying social engineering events.
Level 1+ only. Returns fraction of SE events Blue flagged.
"""
def score(self, phishing_detected: int, total_se_events: int) -> float:
if total_se_events == 0:
return 0.0
return min(phishing_detected / total_se_events, 1.0)
class FalsePositiveReward:
"""Penalty of -0.2 per NPC event incorrectly flagged as attack."""
PENALTY = -0.2
def score(self, false_positives: int) -> float:
return self.PENALTY * false_positives
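# Illustrative component values (hypothetical counts): three of four Red actions
# correctly identified, and two benign NPC events wrongly flagged.
#
#     >>> DetectionReward().score(true_positives=3, total_red_actions=4)
#     0.75
#     >>> FalsePositiveReward().score(false_positives=2)
#     -0.4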
# ---------------------------------------------------------------------------
# Composite Blue
# ---------------------------------------------------------------------------
class CompositeBlueReward:
"""Weighted sum of all Blue reward components.
Default weights::
detection: 0.35
patch: 0.25
availability: 0.15
phishing: 0.05
Complexity is applied as a multiplier on the weighted sum.
False positive penalty is additive (not scaled by tier).
"""
def __init__(
self,
weights: dict[str, float] | None = None,
) -> None:
self.weights = weights or {
"detection": 0.35,
"patch": 0.25,
"availability": 0.15,
"phishing": 0.05,
}
self.detection = DetectionReward()
self.patch = PatchReward()
self.availability = AvailabilityReward()
self.phishing = PhishingDetectionReward()
self.complexity = ComplexityBonus()
self.false_positive = FalsePositiveReward()
# -- low-level score ------------------------------------------------
def score(
self,
true_positives: int = 0,
total_red_actions: int = 0,
exploit_blocked: bool = False,
healthy_services: int = 0,
total_services: int = 0,
phishing_detected: int = 0,
total_se_events: int = 0,
false_positives: int = 0,
tier: int = 1,
) -> float:
total = 0.0
total += self.weights["detection"] * self.detection.score(true_positives, total_red_actions)
total += self.weights["patch"] * self.patch.score(exploit_blocked)
total += self.weights["availability"] * self.availability.score(healthy_services, total_services)
total += self.weights["phishing"] * self.phishing.score(phishing_detected, total_se_events)
scaled = total * self.complexity.multiplier(tier)
scaled += self.false_positive.score(false_positives)
return scaled
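    # Worked example (hypothetical counts): 2 of 4 Red actions detected, the
    # patch validated, all 3 services healthy, no SE events, one false positive,
    # Tier 1. Weighted sum = 0.175 + 0.25 + 0.15 = 0.575, minus the 0.2 FP
    # penalty:
    #
    #     >>> b = CompositeBlueReward()
    #     >>> round(b.score(true_positives=2, total_red_actions=4, exploit_blocked=True,
    #     ...               healthy_services=3, total_services=3, false_positives=1), 3)
    #     0.375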
# -- high-level compute (called by RangeEnvironment.step) -----------
def compute(
self,
action: RangeAction,
observation: RangeObservation,
state: RangeState,
snapshot: SnapshotSpec,
ctx: dict[str, Any] | None = None,
) -> float:
"""Compute composite Blue reward from full environment context."""
ctx = ctx or {}
red_history = ctx.get("red_history", [])
blue_history = ctx.get("blue_history", [])
npc_log = ctx.get("npc_traffic_log", [])
# Detection -- coupled to Red
red_actions = [
r for r in red_history
if r.get("type") not in ("hallucinated_flag", "evidence")
]
blue_findings = [b for b in blue_history if b.get("type") == "finding"]
grounded_findings = [b for b in blue_findings if _is_grounded_finding(b, red_actions)]
tp = min(len(grounded_findings), len(red_actions))
detect_score = self.detection.score(tp, len(red_actions))
# Patch -- require grounded validation signal, not command name heuristics.
patch_validated = bool(ctx.get("patch_validated", False))
if not patch_validated:
patch_events = [
b
for b in blue_history
if b.get("type") == "patch_validation"
and b.get("result") in (True, "success", "blocked")
]
patch_validated = bool(patch_events)
patch_score = self.patch.score(patch_validated)
# Availability
svc = state.services_status or {}
total_svc = len(svc)
healthy = sum(
1 for v in svc.values()
if v in ("running", "healthy", True, "up")
)
avail_score = self.availability.score(healthy, total_svc)
# Phishing detection
se_events = [e for e in npc_log if e.get("type") == "social_engineering"]
phishing_kw = ("phish", "social", "spear", "pretext", "suspicious email")
        phishing_found = [
            b for b in blue_findings
            if any(kw in str(b.get("content", "") or "").lower() for kw in phishing_kw)
        ]
phish_score = self.phishing.score(len(phishing_found), len(se_events))
# False positives -- findings that reference benign NPC traffic
npc_sources = {
e.get("source", "") for e in npc_log if e.get("label") == "benign"
}
fp_count = 0
for f in blue_findings:
            content = str(f.get("content", "") or "").lower()
if any(src.lower() in content for src in npc_sources if src):
fp_count += 1
# Aggregate
components = {
"detection": detect_score,
"patch": patch_score,
"availability": avail_score,
"phishing": phish_score,
}
base = sum(self.weights.get(k, 0.0) * v for k, v in components.items())
scaled = base * self.complexity.multiplier(state.tier)
fp_penalty = self.false_positive.score(fp_count)
total = scaled + fp_penalty
logger.debug(
"Blue reward: components=%s base=%.3f scaled=%.3f fp=%.3f total=%.3f",
components, base, scaled, fp_penalty, total,
)
return total
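# Sketch of the Blue-side inputs compute() expects. Keys mirror the lookups
# above; host/service names and the event payloads are hypothetical.
#
#     ctx = {
#         "red_history": [...],      # same records the Red reward consumes
#         "blue_history": [
#             {"type": "finding", "content": "sqlmap probing the shop-web login"},
#             {"type": "patch_validation", "result": "blocked"},
#         ],
#         "npc_traffic_log": [
#             {"type": "social_engineering", "result": "success", "source": "npc-01"},
#             {"type": "browse", "label": "benign", "source": "npc-02"},
#         ],
#         "patch_validated": False,  # optional short-circuit set by the environment
#     }
#     # state.services_status feeds the availability component, e.g.
#     # {"shop-web": "running", "mail-gw": "down", "db": "healthy"} -> 2 of 3 healthy.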