Ajayyy00 commited on
Commit
a0bdb90
Β·
1 Parent(s): baf5ea9

Add .gitignore, improve play_environment pending_followup tracking

Browse files
Files changed (2) hide show
  1. .gitignore +49 -0
  2. server/play_environment.py +129 -26
.gitignore ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+ *.egg
11
+ .eggs/
12
+
13
+ # Virtual environments
14
+ .venv/
15
+ venv/
16
+ env/
17
+ .env
18
+
19
+ # Testing
20
+ .pytest_cache/
21
+ .coverage
22
+ htmlcov/
23
+ *.lcov
24
+
25
+ # IDEs
26
+ .vscode/
27
+ .idea/
28
+ *.swp
29
+ *.swo
30
+
31
+ # OS
32
+ .DS_Store
33
+ Thumbs.db
34
+
35
+ # Jupyter
36
+ .ipynb_checkpoints/
37
+
38
+ # uv
39
+ .uv/
40
+
41
+ # Training data / model weights (large files)
42
+ *.pt
43
+ *.pth
44
+ *.bin
45
+ *.safetensors
46
+
47
+ # Logs
48
+ *.log
49
+ logs/
server/play_environment.py CHANGED
@@ -158,6 +158,7 @@ class CyberSOCEnvironment(Environment):
158
  self._last_forensics: Optional[ForensicsResult] = None
159
  self._middleware = ActionMiddleware()
160
  self._rng = random.Random(0) # overwritten in reset()
 
161
 
162
  def _reset_rubric(self):
163
  """Initialize live containment requirements for dynamic grading in adaptive mode."""
@@ -246,6 +247,7 @@ class CyberSOCEnvironment(Environment):
246
  self._reset_rubric()
247
  self._fired_step_rewards: set = set()
248
  self._step_reward_total: float = 0.0
 
249
 
250
  # Initialize threat graph from task definition
251
  self._threat_graph = ThreatGraph()
@@ -516,6 +518,8 @@ class CyberSOCEnvironment(Environment):
516
  self._host_index[target_host]["status"] = "isolated"
517
  if self._threat_graph is not None and target_host in self._threat_graph.hosts:
518
  self._threat_graph.hosts[target_host].status = "isolated"
 
 
519
  return 0.10, f"Isolated single host '{target_host}'"
520
 
521
  subnet = action.subnet
@@ -531,6 +535,8 @@ class CyberSOCEnvironment(Environment):
531
  host["status"] = "isolated"
532
  if self._threat_graph is not None and host["hostname"] in self._threat_graph.hosts:
533
  self._threat_graph.hosts[host["hostname"]].status = "isolated"
 
 
534
 
535
  self._state.isolated_subnets.append(subnet)
536
 
@@ -571,6 +577,22 @@ class CyberSOCEnvironment(Environment):
571
 
572
  self._state.blocked_iocs.append(ioc)
573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  # Check if this IOC is relevant to any active threat
575
  reward = 0.0
576
  relevant = False
@@ -641,6 +663,7 @@ class CyberSOCEnvironment(Environment):
641
  if hostname not in self._state.forensics_run:
642
  if is_compromised:
643
  reward = 0.10 # Good: found evidence
 
644
  else:
645
  reward = 0.02 # Cleared a host (some value)
646
  self._state.forensics_run.append(hostname)
@@ -689,6 +712,8 @@ class CyberSOCEnvironment(Environment):
689
  # Kill the process
690
  host["running_processes"].remove(process)
691
  self._state.killed_processes.append({"hostname": hostname, "process": process})
 
 
692
 
693
  # Check if this was a malicious process
694
  reward = 0.0
@@ -913,9 +938,9 @@ class CyberSOCEnvironment(Environment):
913
  def _compute_reward_dimensions(self) -> Dict[str, float]:
914
  """Per-step heuristic partial scores for all 10 grading dimensions.
915
 
916
- Updated every step so GRPO can assign credit without waiting for the
917
- terminal grade. Scores are in [0, 1]; the terminal grade_breakdown
918
- (from grade_episode) supersedes these once the plan is submitted.
919
  """
920
  state = self._state
921
  task_chain = self._task_def.get("attack_chain", [])
@@ -929,26 +954,104 @@ class CyberSOCEnvironment(Environment):
929
  for t in task_chain
930
  ))
931
 
932
- # 1. threat_containment β€” fraction of threats neutralised
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
  threat_containment = min(1.0, len(state.contained_threats) / total_threats)
934
 
935
- # 2. ioc_blocking β€” fraction of known IOCs blocked
936
- ioc_blocking = min(1.0, len(state.blocked_iocs) / total_iocs)
 
937
 
938
- # 3. forensic_investigation β€” fraction of compromised hosts investigated
939
- forensic_investigation = min(1.0, len(state.forensics_run) / total_compromised)
 
 
 
 
 
 
 
 
 
 
940
 
941
- # 4. siem_correlation β€” binary: did the agent correlate alerts?
942
- siem_correlation = 1.0 if state.correlated_alert_pairs else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943
 
944
- # 5. threat_intel_usage β€” fraction of IOCs enriched with threat intel
945
- threat_intel_usage = min(1.0, len(state.enriched_iocs) / total_iocs)
 
946
 
947
  # 6. vuln_root_cause β€” fraction of threats with a scanned host
948
  vuln_root_cause = min(1.0, len(state.scanned_hosts) / total_threats)
949
 
950
- # 7. business_impact β€” lower impact is better
951
- business_impact = max(0.0, 1.0 - state.business_impact)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
 
953
  # 8. step_efficiency β€” reward early resolution
954
  ratio = state.step_count / max(1, state.max_steps)
@@ -960,25 +1063,25 @@ class CyberSOCEnvironment(Environment):
960
  else:
961
  plan_coverage = min(0.5, len(state.contained_threats) / total_threats * 0.5)
962
 
963
- # 10. plan_evidence_quality β€” confidence of submitted plan; else proxy from investigation depth
964
  if state.submitted_plan and self._plan_entries:
965
  avg_conf = sum(e.get("confidence", 0.0) for e in self._plan_entries) / len(self._plan_entries)
966
  plan_evidence_quality = float(avg_conf)
967
  else:
968
- evidence_count = len(state.forensics_run) + len(state.enriched_iocs) + len(state.scanned_hosts)
969
  plan_evidence_quality = min(0.5, evidence_count / (total_compromised * 3) * 0.5)
970
 
971
  return {
972
- "threat_containment": round(threat_containment, 4),
973
- "ioc_blocking": round(ioc_blocking, 4),
974
- "forensic_investigation":round(forensic_investigation, 4),
975
- "siem_correlation": round(siem_correlation, 4),
976
- "threat_intel_usage": round(threat_intel_usage, 4),
977
- "vuln_root_cause": round(vuln_root_cause, 4),
978
- "business_impact": round(business_impact, 4),
979
- "step_efficiency": round(step_efficiency, 4),
980
- "plan_coverage": round(plan_coverage, 4),
981
- "plan_evidence_quality": round(plan_evidence_quality, 4),
982
  }
983
 
984
  def _get_current_phase(self) -> str:
 
158
  self._last_forensics: Optional[ForensicsResult] = None
159
  self._middleware = ActionMiddleware()
160
  self._rng = random.Random(0) # overwritten in reset()
161
+ self._pending_followup: Dict[str, bool] = {} # hostname -> responded_to
162
 
163
  def _reset_rubric(self):
164
  """Initialize live containment requirements for dynamic grading in adaptive mode."""
 
247
  self._reset_rubric()
248
  self._fired_step_rewards: set = set()
249
  self._step_reward_total: float = 0.0
250
+ self._pending_followup: Dict[str, bool] = {}
251
 
252
  # Initialize threat graph from task definition
253
  self._threat_graph = ThreatGraph()
 
518
  self._host_index[target_host]["status"] = "isolated"
519
  if self._threat_graph is not None and target_host in self._threat_graph.hosts:
520
  self._threat_graph.hosts[target_host].status = "isolated"
521
+ if target_host in self._pending_followup:
522
+ self._pending_followup[target_host] = True
523
  return 0.10, f"Isolated single host '{target_host}'"
524
 
525
  subnet = action.subnet
 
535
  host["status"] = "isolated"
536
  if self._threat_graph is not None and host["hostname"] in self._threat_graph.hosts:
537
  self._threat_graph.hosts[host["hostname"]].status = "isolated"
538
+ if host["hostname"] in self._pending_followup:
539
+ self._pending_followup[host["hostname"]] = True
540
 
541
  self._state.isolated_subnets.append(subnet)
542
 
 
577
 
578
  self._state.blocked_iocs.append(ioc)
579
 
580
+ # Mark any forensics-confirmed host as responded-to if this IOC belongs to its threat chain
581
+ for hostname, responded in list(self._pending_followup.items()):
582
+ if responded:
583
+ continue
584
+ for threat in self._task_def["attack_chain"]:
585
+ if hostname in threat["compromised_hosts"]:
586
+ all_threat_iocs = (
587
+ threat["iocs"].get("hashes", [])
588
+ + threat["iocs"].get("ips", [])
589
+ + threat["iocs"].get("domains", [])
590
+ + threat.get("c2_servers", [])
591
+ )
592
+ if ioc in all_threat_iocs:
593
+ self._pending_followup[hostname] = True
594
+ break
595
+
596
  # Check if this IOC is relevant to any active threat
597
  reward = 0.0
598
  relevant = False
 
663
  if hostname not in self._state.forensics_run:
664
  if is_compromised:
665
  reward = 0.10 # Good: found evidence
666
+ self._pending_followup.setdefault(hostname, False) # needs response action
667
  else:
668
  reward = 0.02 # Cleared a host (some value)
669
  self._state.forensics_run.append(hostname)
 
712
  # Kill the process
713
  host["running_processes"].remove(process)
714
  self._state.killed_processes.append({"hostname": hostname, "process": process})
715
+ if hostname in self._pending_followup:
716
+ self._pending_followup[hostname] = True
717
 
718
  # Check if this was a malicious process
719
  reward = 0.0
 
938
  def _compute_reward_dimensions(self) -> Dict[str, float]:
939
  """Per-step heuristic partial scores for all 10 grading dimensions.
940
 
941
+ Evidence-gated: actions only score if prior evidence justified them.
942
+ Result-usage: forensics-confirmed hosts with no followup are penalized.
943
+ Scores in [0, 1]; terminal grade_breakdown supersedes these on plan submission.
944
  """
945
  state = self._state
946
  task_chain = self._task_def.get("attack_chain", [])
 
954
  for t in task_chain
955
  ))
956
 
957
+ # --- Build evidence pools: what the agent could have observed ---
958
+ # Hosts mentioned as alert source (visible from turn 0)
959
+ alert_source_hosts: set = set()
960
+ for a in self._task_def.get("initial_alerts", []):
961
+ alert_source_hosts.add(a.get("source_host", ""))
962
+ for a in self._alert_queue:
963
+ alert_source_hosts.add(a.get("source_host", ""))
964
+ alert_source_hosts.discard("")
965
+
966
+ # IOCs visible from alert ioc_indicators
967
+ alert_iocs: set = set()
968
+ for a_list in (self._task_def.get("initial_alerts", []), self._alert_queue):
969
+ for a in a_list:
970
+ for ioc in a.get("ioc_indicators", []):
971
+ alert_iocs.add(ioc)
972
+
973
+ # IOCs revealed by running forensics on a host
974
+ forensics_revealed_iocs: set = set()
975
+ for hostname in state.forensics_run:
976
+ for threat in task_chain:
977
+ if hostname in threat.get("compromised_hosts", []):
978
+ forensics_revealed_iocs.update(threat.get("c2_servers", []))
979
+ forensics_revealed_iocs.update(threat["iocs"].get("hashes", []))
980
+ forensics_revealed_iocs.update(threat["iocs"].get("ips", []))
981
+ forensics_revealed_iocs.update(threat["iocs"].get("domains", []))
982
+
983
+ discovered_iocs = alert_iocs | forensics_revealed_iocs
984
+
985
+ # 1. threat_containment β€” fraction of threats neutralised (no evidence gate; outcome IS evidence)
986
  threat_containment = min(1.0, len(state.contained_threats) / total_threats)
987
 
988
+ # 2. ioc_blocking β€” only blocks of IOCs the agent actually discovered count
989
+ justified_blocks = [ioc for ioc in state.blocked_iocs if ioc in discovered_iocs]
990
+ ioc_blocking = min(1.0, len(justified_blocks) / total_iocs)
991
 
992
+ # 3. forensic_investigation β€” only counts forensics on alert-mentioned or previously queried
993
+ # hosts; penalizes confirmed compromises left with no response action
994
+ justified_forensics = [
995
+ h for h in state.forensics_run
996
+ if h in alert_source_hosts or h in state.queried_hosts
997
+ ]
998
+ pending = self._pending_followup
999
+ unresponded = sum(1 for v in pending.values() if not v)
1000
+ followup_penalty = min(0.30, unresponded * 0.10)
1001
+ forensic_investigation = max(0.0,
1002
+ min(1.0, len(justified_forensics) / total_compromised) - followup_penalty
1003
+ )
1004
 
1005
+ # 4. siem_correlation β€” scored by semantic quality (shared source hosts or IOCs)
1006
+ if not state.correlated_alert_pairs:
1007
+ siem_correlation = 0.0
1008
+ else:
1009
+ alert_map: Dict[str, Any] = {}
1010
+ for a in self._task_def.get("initial_alerts", []):
1011
+ alert_map[a.get("alert_id", "")] = a
1012
+ for a in self._alert_queue:
1013
+ alert_map[a.get("alert_id", "")] = a
1014
+ quality_scores = []
1015
+ for pair in state.correlated_alert_pairs:
1016
+ pair_alerts = [alert_map[aid] for aid in pair if aid in alert_map]
1017
+ if len(pair_alerts) < 2:
1018
+ quality_scores.append(0.3)
1019
+ continue
1020
+ sources = [a.get("source_host") for a in pair_alerts]
1021
+ ioc_sets = [set(a.get("ioc_indicators", [])) for a in pair_alerts]
1022
+ shared_hosts = len(sources) != len({s for s in sources if s})
1023
+ shared_iocs = bool(ioc_sets[0] & ioc_sets[1]) if len(ioc_sets) >= 2 else False
1024
+ quality_scores.append(1.0 if (shared_hosts or shared_iocs) else 0.2)
1025
+ siem_correlation = sum(quality_scores) / max(1, len(quality_scores))
1026
 
1027
+ # 5. threat_intel_usage β€” only enrichments of discovered IOCs count
1028
+ justified_enrichments = [ioc for ioc in state.enriched_iocs if ioc in discovered_iocs]
1029
+ threat_intel_usage = min(1.0, len(justified_enrichments) / total_iocs)
1030
 
1031
  # 6. vuln_root_cause β€” fraction of threats with a scanned host
1032
  vuln_root_cause = min(1.0, len(state.scanned_hosts) / total_threats)
1033
 
1034
+ # 7. business_impact β€” proportionate isolation + low overall impact
1035
+ # Reward: isolating confirmed-compromised hosts Penalize: isolating clean hosts
1036
+ isolated_host_set = {
1037
+ h for h, hd in self._host_index.items() if hd.get("status") == "isolated"
1038
+ } if self._host_index else set()
1039
+ compromised_host_set = {
1040
+ h for threat in task_chain for h in threat.get("compromised_hosts", [])
1041
+ }
1042
+ if isolated_host_set:
1043
+ over_isolated = isolated_host_set - compromised_host_set
1044
+ isolation_proportion = (
1045
+ len(isolated_host_set - over_isolated) / len(isolated_host_set)
1046
+ )
1047
+ over_iso_penalty = min(0.40, len(over_isolated) * 0.15)
1048
+ else:
1049
+ isolation_proportion = 1.0
1050
+ over_iso_penalty = 0.0
1051
+ raw_impact_score = max(0.0, 1.0 - state.business_impact)
1052
+ business_impact = max(0.0, min(1.0,
1053
+ 0.6 * raw_impact_score + 0.4 * isolation_proportion - over_iso_penalty
1054
+ ))
1055
 
1056
  # 8. step_efficiency β€” reward early resolution
1057
  ratio = state.step_count / max(1, state.max_steps)
 
1063
  else:
1064
  plan_coverage = min(0.5, len(state.contained_threats) / total_threats * 0.5)
1065
 
1066
+ # 10. plan_evidence_quality β€” confidence of submitted plan; else evidence depth proxy
1067
  if state.submitted_plan and self._plan_entries:
1068
  avg_conf = sum(e.get("confidence", 0.0) for e in self._plan_entries) / len(self._plan_entries)
1069
  plan_evidence_quality = float(avg_conf)
1070
  else:
1071
+ evidence_count = len(justified_forensics) + len(justified_enrichments) + len(state.scanned_hosts)
1072
  plan_evidence_quality = min(0.5, evidence_count / (total_compromised * 3) * 0.5)
1073
 
1074
  return {
1075
+ "threat_containment": round(threat_containment, 4),
1076
+ "ioc_blocking": round(ioc_blocking, 4),
1077
+ "forensic_investigation": round(forensic_investigation, 4),
1078
+ "siem_correlation": round(siem_correlation, 4),
1079
+ "threat_intel_usage": round(threat_intel_usage, 4),
1080
+ "vuln_root_cause": round(vuln_root_cause, 4),
1081
+ "business_impact": round(business_impact, 4),
1082
+ "step_efficiency": round(step_efficiency, 4),
1083
+ "plan_coverage": round(plan_coverage, 4),
1084
+ "plan_evidence_quality": round(plan_evidence_quality, 4),
1085
  }
1086
 
1087
  def _get_current_phase(self) -> str: