Ajayyy00 committed on
Commit ·
a0bdb90
1
Parent(s): baf5ea9
Add .gitignore, improve play_environment pending_followup tracking
Browse files- .gitignore +49 -0
- server/play_environment.py +129 -26
.gitignore
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
*.egg
|
| 11 |
+
.eggs/
|
| 12 |
+
|
| 13 |
+
# Virtual environments
|
| 14 |
+
.venv/
|
| 15 |
+
venv/
|
| 16 |
+
env/
|
| 17 |
+
.env
|
| 18 |
+
|
| 19 |
+
# Testing
|
| 20 |
+
.pytest_cache/
|
| 21 |
+
.coverage
|
| 22 |
+
htmlcov/
|
| 23 |
+
*.lcov
|
| 24 |
+
|
| 25 |
+
# IDEs
|
| 26 |
+
.vscode/
|
| 27 |
+
.idea/
|
| 28 |
+
*.swp
|
| 29 |
+
*.swo
|
| 30 |
+
|
| 31 |
+
# OS
|
| 32 |
+
.DS_Store
|
| 33 |
+
Thumbs.db
|
| 34 |
+
|
| 35 |
+
# Jupyter
|
| 36 |
+
.ipynb_checkpoints/
|
| 37 |
+
|
| 38 |
+
# uv
|
| 39 |
+
.uv/
|
| 40 |
+
|
| 41 |
+
# Training data / model weights (large files)
|
| 42 |
+
*.pt
|
| 43 |
+
*.pth
|
| 44 |
+
*.bin
|
| 45 |
+
*.safetensors
|
| 46 |
+
|
| 47 |
+
# Logs
|
| 48 |
+
*.log
|
| 49 |
+
logs/
|
server/play_environment.py
CHANGED
|
@@ -158,6 +158,7 @@ class CyberSOCEnvironment(Environment):
|
|
| 158 |
self._last_forensics: Optional[ForensicsResult] = None
|
| 159 |
self._middleware = ActionMiddleware()
|
| 160 |
self._rng = random.Random(0) # overwritten in reset()
|
|
|
|
| 161 |
|
| 162 |
def _reset_rubric(self):
|
| 163 |
"""Initialize live containment requirements for dynamic grading in adaptive mode."""
|
|
@@ -246,6 +247,7 @@ class CyberSOCEnvironment(Environment):
|
|
| 246 |
self._reset_rubric()
|
| 247 |
self._fired_step_rewards: set = set()
|
| 248 |
self._step_reward_total: float = 0.0
|
|
|
|
| 249 |
|
| 250 |
# Initialize threat graph from task definition
|
| 251 |
self._threat_graph = ThreatGraph()
|
|
@@ -516,6 +518,8 @@ class CyberSOCEnvironment(Environment):
|
|
| 516 |
self._host_index[target_host]["status"] = "isolated"
|
| 517 |
if self._threat_graph is not None and target_host in self._threat_graph.hosts:
|
| 518 |
self._threat_graph.hosts[target_host].status = "isolated"
|
|
|
|
|
|
|
| 519 |
return 0.10, f"Isolated single host '{target_host}'"
|
| 520 |
|
| 521 |
subnet = action.subnet
|
|
@@ -531,6 +535,8 @@ class CyberSOCEnvironment(Environment):
|
|
| 531 |
host["status"] = "isolated"
|
| 532 |
if self._threat_graph is not None and host["hostname"] in self._threat_graph.hosts:
|
| 533 |
self._threat_graph.hosts[host["hostname"]].status = "isolated"
|
|
|
|
|
|
|
| 534 |
|
| 535 |
self._state.isolated_subnets.append(subnet)
|
| 536 |
|
|
@@ -571,6 +577,22 @@ class CyberSOCEnvironment(Environment):
|
|
| 571 |
|
| 572 |
self._state.blocked_iocs.append(ioc)
|
| 573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
# Check if this IOC is relevant to any active threat
|
| 575 |
reward = 0.0
|
| 576 |
relevant = False
|
|
@@ -641,6 +663,7 @@ class CyberSOCEnvironment(Environment):
|
|
| 641 |
if hostname not in self._state.forensics_run:
|
| 642 |
if is_compromised:
|
| 643 |
reward = 0.10 # Good: found evidence
|
|
|
|
| 644 |
else:
|
| 645 |
reward = 0.02 # Cleared a host (some value)
|
| 646 |
self._state.forensics_run.append(hostname)
|
|
@@ -689,6 +712,8 @@ class CyberSOCEnvironment(Environment):
|
|
| 689 |
# Kill the process
|
| 690 |
host["running_processes"].remove(process)
|
| 691 |
self._state.killed_processes.append({"hostname": hostname, "process": process})
|
|
|
|
|
|
|
| 692 |
|
| 693 |
# Check if this was a malicious process
|
| 694 |
reward = 0.0
|
|
@@ -913,9 +938,9 @@ class CyberSOCEnvironment(Environment):
|
|
| 913 |
def _compute_reward_dimensions(self) -> Dict[str, float]:
|
| 914 |
"""Per-step heuristic partial scores for all 10 grading dimensions.
|
| 915 |
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
"""
|
| 920 |
state = self._state
|
| 921 |
task_chain = self._task_def.get("attack_chain", [])
|
|
@@ -929,26 +954,104 @@ class CyberSOCEnvironment(Environment):
|
|
| 929 |
for t in task_chain
|
| 930 |
))
|
| 931 |
|
| 932 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
threat_containment = min(1.0, len(state.contained_threats) / total_threats)
|
| 934 |
|
| 935 |
-
# 2. ioc_blocking β
|
| 936 |
-
|
|
|
|
| 937 |
|
| 938 |
-
# 3. forensic_investigation β
|
| 939 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 940 |
|
| 941 |
-
# 4. siem_correlation β
|
| 942 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 943 |
|
| 944 |
-
# 5. threat_intel_usage β
|
| 945 |
-
|
|
|
|
| 946 |
|
| 947 |
# 6. vuln_root_cause — fraction of threats with a scanned host
|
| 948 |
vuln_root_cause = min(1.0, len(state.scanned_hosts) / total_threats)
|
| 949 |
|
| 950 |
-
# 7. business_impact β
|
| 951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
|
| 953 |
# 8. step_efficiency — reward early resolution
|
| 954 |
ratio = state.step_count / max(1, state.max_steps)
|
|
@@ -960,25 +1063,25 @@ class CyberSOCEnvironment(Environment):
|
|
| 960 |
else:
|
| 961 |
plan_coverage = min(0.5, len(state.contained_threats) / total_threats * 0.5)
|
| 962 |
|
| 963 |
-
# 10. plan_evidence_quality β confidence of submitted plan; else
|
| 964 |
if state.submitted_plan and self._plan_entries:
|
| 965 |
avg_conf = sum(e.get("confidence", 0.0) for e in self._plan_entries) / len(self._plan_entries)
|
| 966 |
plan_evidence_quality = float(avg_conf)
|
| 967 |
else:
|
| 968 |
-
evidence_count = len(
|
| 969 |
plan_evidence_quality = min(0.5, evidence_count / (total_compromised * 3) * 0.5)
|
| 970 |
|
| 971 |
return {
|
| 972 |
-
"threat_containment":
|
| 973 |
-
"ioc_blocking":
|
| 974 |
-
"forensic_investigation":round(forensic_investigation, 4),
|
| 975 |
-
"siem_correlation":
|
| 976 |
-
"threat_intel_usage":
|
| 977 |
-
"vuln_root_cause":
|
| 978 |
-
"business_impact":
|
| 979 |
-
"step_efficiency":
|
| 980 |
-
"plan_coverage":
|
| 981 |
-
"plan_evidence_quality":
|
| 982 |
}
|
| 983 |
|
| 984 |
def _get_current_phase(self) -> str:
|
|
|
|
| 158 |
self._last_forensics: Optional[ForensicsResult] = None
|
| 159 |
self._middleware = ActionMiddleware()
|
| 160 |
self._rng = random.Random(0) # overwritten in reset()
|
| 161 |
+
self._pending_followup: Dict[str, bool] = {} # hostname -> responded_to
|
| 162 |
|
| 163 |
def _reset_rubric(self):
|
| 164 |
"""Initialize live containment requirements for dynamic grading in adaptive mode."""
|
|
|
|
| 247 |
self._reset_rubric()
|
| 248 |
self._fired_step_rewards: set = set()
|
| 249 |
self._step_reward_total: float = 0.0
|
| 250 |
+
self._pending_followup: Dict[str, bool] = {}
|
| 251 |
|
| 252 |
# Initialize threat graph from task definition
|
| 253 |
self._threat_graph = ThreatGraph()
|
|
|
|
| 518 |
self._host_index[target_host]["status"] = "isolated"
|
| 519 |
if self._threat_graph is not None and target_host in self._threat_graph.hosts:
|
| 520 |
self._threat_graph.hosts[target_host].status = "isolated"
|
| 521 |
+
if target_host in self._pending_followup:
|
| 522 |
+
self._pending_followup[target_host] = True
|
| 523 |
return 0.10, f"Isolated single host '{target_host}'"
|
| 524 |
|
| 525 |
subnet = action.subnet
|
|
|
|
| 535 |
host["status"] = "isolated"
|
| 536 |
if self._threat_graph is not None and host["hostname"] in self._threat_graph.hosts:
|
| 537 |
self._threat_graph.hosts[host["hostname"]].status = "isolated"
|
| 538 |
+
if host["hostname"] in self._pending_followup:
|
| 539 |
+
self._pending_followup[host["hostname"]] = True
|
| 540 |
|
| 541 |
self._state.isolated_subnets.append(subnet)
|
| 542 |
|
|
|
|
| 577 |
|
| 578 |
self._state.blocked_iocs.append(ioc)
|
| 579 |
|
| 580 |
+
# Mark any forensics-confirmed host as responded-to if this IOC belongs to its threat chain
|
| 581 |
+
for hostname, responded in list(self._pending_followup.items()):
|
| 582 |
+
if responded:
|
| 583 |
+
continue
|
| 584 |
+
for threat in self._task_def["attack_chain"]:
|
| 585 |
+
if hostname in threat["compromised_hosts"]:
|
| 586 |
+
all_threat_iocs = (
|
| 587 |
+
threat["iocs"].get("hashes", [])
|
| 588 |
+
+ threat["iocs"].get("ips", [])
|
| 589 |
+
+ threat["iocs"].get("domains", [])
|
| 590 |
+
+ threat.get("c2_servers", [])
|
| 591 |
+
)
|
| 592 |
+
if ioc in all_threat_iocs:
|
| 593 |
+
self._pending_followup[hostname] = True
|
| 594 |
+
break
|
| 595 |
+
|
| 596 |
# Check if this IOC is relevant to any active threat
|
| 597 |
reward = 0.0
|
| 598 |
relevant = False
|
|
|
|
| 663 |
if hostname not in self._state.forensics_run:
|
| 664 |
if is_compromised:
|
| 665 |
reward = 0.10 # Good: found evidence
|
| 666 |
+
self._pending_followup.setdefault(hostname, False) # needs response action
|
| 667 |
else:
|
| 668 |
reward = 0.02 # Cleared a host (some value)
|
| 669 |
self._state.forensics_run.append(hostname)
|
|
|
|
| 712 |
# Kill the process
|
| 713 |
host["running_processes"].remove(process)
|
| 714 |
self._state.killed_processes.append({"hostname": hostname, "process": process})
|
| 715 |
+
if hostname in self._pending_followup:
|
| 716 |
+
self._pending_followup[hostname] = True
|
| 717 |
|
| 718 |
# Check if this was a malicious process
|
| 719 |
reward = 0.0
|
|
|
|
| 938 |
def _compute_reward_dimensions(self) -> Dict[str, float]:
|
| 939 |
"""Per-step heuristic partial scores for all 10 grading dimensions.
|
| 940 |
|
| 941 |
+
Evidence-gated: actions only score if prior evidence justified them.
|
| 942 |
+
Result-usage: forensics-confirmed hosts with no followup are penalized.
|
| 943 |
+
Scores in [0, 1]; terminal grade_breakdown supersedes these on plan submission.
|
| 944 |
"""
|
| 945 |
state = self._state
|
| 946 |
task_chain = self._task_def.get("attack_chain", [])
|
|
|
|
| 954 |
for t in task_chain
|
| 955 |
))
|
| 956 |
|
| 957 |
+
# --- Build evidence pools: what the agent could have observed ---
|
| 958 |
+
# Hosts mentioned as alert source (visible from turn 0)
|
| 959 |
+
alert_source_hosts: set = set()
|
| 960 |
+
for a in self._task_def.get("initial_alerts", []):
|
| 961 |
+
alert_source_hosts.add(a.get("source_host", ""))
|
| 962 |
+
for a in self._alert_queue:
|
| 963 |
+
alert_source_hosts.add(a.get("source_host", ""))
|
| 964 |
+
alert_source_hosts.discard("")
|
| 965 |
+
|
| 966 |
+
# IOCs visible from alert ioc_indicators
|
| 967 |
+
alert_iocs: set = set()
|
| 968 |
+
for a_list in (self._task_def.get("initial_alerts", []), self._alert_queue):
|
| 969 |
+
for a in a_list:
|
| 970 |
+
for ioc in a.get("ioc_indicators", []):
|
| 971 |
+
alert_iocs.add(ioc)
|
| 972 |
+
|
| 973 |
+
# IOCs revealed by running forensics on a host
|
| 974 |
+
forensics_revealed_iocs: set = set()
|
| 975 |
+
for hostname in state.forensics_run:
|
| 976 |
+
for threat in task_chain:
|
| 977 |
+
if hostname in threat.get("compromised_hosts", []):
|
| 978 |
+
forensics_revealed_iocs.update(threat.get("c2_servers", []))
|
| 979 |
+
forensics_revealed_iocs.update(threat["iocs"].get("hashes", []))
|
| 980 |
+
forensics_revealed_iocs.update(threat["iocs"].get("ips", []))
|
| 981 |
+
forensics_revealed_iocs.update(threat["iocs"].get("domains", []))
|
| 982 |
+
|
| 983 |
+
discovered_iocs = alert_iocs | forensics_revealed_iocs
|
| 984 |
+
|
| 985 |
+
# 1. threat_containment — fraction of threats neutralised (no evidence gate; outcome IS evidence)
|
| 986 |
threat_containment = min(1.0, len(state.contained_threats) / total_threats)
|
| 987 |
|
| 988 |
+
# 2. ioc_blocking — only blocks of IOCs the agent actually discovered count
|
| 989 |
+
justified_blocks = [ioc for ioc in state.blocked_iocs if ioc in discovered_iocs]
|
| 990 |
+
ioc_blocking = min(1.0, len(justified_blocks) / total_iocs)
|
| 991 |
|
| 992 |
+
# 3. forensic_investigation — only counts forensics on alert-mentioned or previously queried
|
| 993 |
+
# hosts; penalizes confirmed compromises left with no response action
|
| 994 |
+
justified_forensics = [
|
| 995 |
+
h for h in state.forensics_run
|
| 996 |
+
if h in alert_source_hosts or h in state.queried_hosts
|
| 997 |
+
]
|
| 998 |
+
pending = self._pending_followup
|
| 999 |
+
unresponded = sum(1 for v in pending.values() if not v)
|
| 1000 |
+
followup_penalty = min(0.30, unresponded * 0.10)
|
| 1001 |
+
forensic_investigation = max(0.0,
|
| 1002 |
+
min(1.0, len(justified_forensics) / total_compromised) - followup_penalty
|
| 1003 |
+
)
|
| 1004 |
|
| 1005 |
+
# 4. siem_correlation — scored by semantic quality (shared source hosts or IOCs)
|
| 1006 |
+
if not state.correlated_alert_pairs:
|
| 1007 |
+
siem_correlation = 0.0
|
| 1008 |
+
else:
|
| 1009 |
+
alert_map: Dict[str, Any] = {}
|
| 1010 |
+
for a in self._task_def.get("initial_alerts", []):
|
| 1011 |
+
alert_map[a.get("alert_id", "")] = a
|
| 1012 |
+
for a in self._alert_queue:
|
| 1013 |
+
alert_map[a.get("alert_id", "")] = a
|
| 1014 |
+
quality_scores = []
|
| 1015 |
+
for pair in state.correlated_alert_pairs:
|
| 1016 |
+
pair_alerts = [alert_map[aid] for aid in pair if aid in alert_map]
|
| 1017 |
+
if len(pair_alerts) < 2:
|
| 1018 |
+
quality_scores.append(0.3)
|
| 1019 |
+
continue
|
| 1020 |
+
sources = [a.get("source_host") for a in pair_alerts]
|
| 1021 |
+
ioc_sets = [set(a.get("ioc_indicators", [])) for a in pair_alerts]
|
| 1022 |
+
shared_hosts = len(sources) != len({s for s in sources if s})
|
| 1023 |
+
shared_iocs = bool(ioc_sets[0] & ioc_sets[1]) if len(ioc_sets) >= 2 else False
|
| 1024 |
+
quality_scores.append(1.0 if (shared_hosts or shared_iocs) else 0.2)
|
| 1025 |
+
siem_correlation = sum(quality_scores) / max(1, len(quality_scores))
|
| 1026 |
|
| 1027 |
+
# 5. threat_intel_usage — only enrichments of discovered IOCs count
|
| 1028 |
+
justified_enrichments = [ioc for ioc in state.enriched_iocs if ioc in discovered_iocs]
|
| 1029 |
+
threat_intel_usage = min(1.0, len(justified_enrichments) / total_iocs)
|
| 1030 |
|
| 1031 |
# 6. vuln_root_cause — fraction of threats with a scanned host
|
| 1032 |
vuln_root_cause = min(1.0, len(state.scanned_hosts) / total_threats)
|
| 1033 |
|
| 1034 |
+
# 7. business_impact — proportionate isolation + low overall impact
|
| 1035 |
+
# Reward: isolating confirmed-compromised hosts Penalize: isolating clean hosts
|
| 1036 |
+
isolated_host_set = {
|
| 1037 |
+
h for h, hd in self._host_index.items() if hd.get("status") == "isolated"
|
| 1038 |
+
} if self._host_index else set()
|
| 1039 |
+
compromised_host_set = {
|
| 1040 |
+
h for threat in task_chain for h in threat.get("compromised_hosts", [])
|
| 1041 |
+
}
|
| 1042 |
+
if isolated_host_set:
|
| 1043 |
+
over_isolated = isolated_host_set - compromised_host_set
|
| 1044 |
+
isolation_proportion = (
|
| 1045 |
+
len(isolated_host_set - over_isolated) / len(isolated_host_set)
|
| 1046 |
+
)
|
| 1047 |
+
over_iso_penalty = min(0.40, len(over_isolated) * 0.15)
|
| 1048 |
+
else:
|
| 1049 |
+
isolation_proportion = 1.0
|
| 1050 |
+
over_iso_penalty = 0.0
|
| 1051 |
+
raw_impact_score = max(0.0, 1.0 - state.business_impact)
|
| 1052 |
+
business_impact = max(0.0, min(1.0,
|
| 1053 |
+
0.6 * raw_impact_score + 0.4 * isolation_proportion - over_iso_penalty
|
| 1054 |
+
))
|
| 1055 |
|
| 1056 |
# 8. step_efficiency — reward early resolution
|
| 1057 |
ratio = state.step_count / max(1, state.max_steps)
|
|
|
|
| 1063 |
else:
|
| 1064 |
plan_coverage = min(0.5, len(state.contained_threats) / total_threats * 0.5)
|
| 1065 |
|
| 1066 |
+
# 10. plan_evidence_quality — confidence of submitted plan; else evidence depth proxy
|
| 1067 |
if state.submitted_plan and self._plan_entries:
|
| 1068 |
avg_conf = sum(e.get("confidence", 0.0) for e in self._plan_entries) / len(self._plan_entries)
|
| 1069 |
plan_evidence_quality = float(avg_conf)
|
| 1070 |
else:
|
| 1071 |
+
evidence_count = len(justified_forensics) + len(justified_enrichments) + len(state.scanned_hosts)
|
| 1072 |
plan_evidence_quality = min(0.5, evidence_count / (total_compromised * 3) * 0.5)
|
| 1073 |
|
| 1074 |
return {
|
| 1075 |
+
"threat_containment": round(threat_containment, 4),
|
| 1076 |
+
"ioc_blocking": round(ioc_blocking, 4),
|
| 1077 |
+
"forensic_investigation": round(forensic_investigation, 4),
|
| 1078 |
+
"siem_correlation": round(siem_correlation, 4),
|
| 1079 |
+
"threat_intel_usage": round(threat_intel_usage, 4),
|
| 1080 |
+
"vuln_root_cause": round(vuln_root_cause, 4),
|
| 1081 |
+
"business_impact": round(business_impact, 4),
|
| 1082 |
+
"step_efficiency": round(step_efficiency, 4),
|
| 1083 |
+
"plan_coverage": round(plan_coverage, 4),
|
| 1084 |
+
"plan_evidence_quality": round(plan_evidence_quality, 4),
|
| 1085 |
}
|
| 1086 |
|
| 1087 |
def _get_current_phase(self) -> str:
|