# SOC-env / server / SOC_env_environment.py
# Uploaded by ApoorvaBadoni using huggingface_hub (commit 4c8f878, verified).
from typing import List, Literal, Optional
from uuid import uuid4
import random
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from models import SOCAction, SOCObservation
from scenarios import EASY_SCENARIOS, HARD_SCENARIOS, MEDIUM_SCENARIOS, SCENARIOS
except ImportError:
from ..models import SOCAction, SOCObservation
from ..scenarios import EASY_SCENARIOS, HARD_SCENARIOS, MEDIUM_SCENARIOS, SCENARIOS
# Lookup table from scenario "id" to the scenario record itself; used when an
# environment is pinned to one specific scenario.
SCENARIO_BY_ID = {scenario["id"]: scenario for scenario in SCENARIOS}

# Accepted values for the environment's ``difficulty`` setting.
Difficulty = Literal["easy", "medium", "hard", "random"]

# Decisions that close the episode as soon as they are taken.
TERMINAL_ACTIONS = {
    "ignore",
    "escalate",
    "patch_system",
}

# Step budget per scenario difficulty.
MAX_STEPS = {
    "easy": 5,
    "medium": 8,
    "hard": 12,
}

# Actions that reveal additional investigation context
INVESTIGATION_ACTIONS = {
    "investigate",
    "query_logs",
    "check_threat_intel",
    "run_sandbox",
}
class SOCEnvironment(Environment):
    """Security-operations triage environment driven by scripted scenarios.

    Each episode presents one alert scenario. The agent submits decisions via
    ``step``; investigation-type decisions progressively reveal scenario
    context, and any decision in ``TERMINAL_ACTIONS`` (or exhausting the
    per-difficulty step budget) ends the episode.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Step budget used when a scenario's difficulty has no MAX_STEPS entry
    # (or when no scenario is loaded).
    _DEFAULT_MAX_STEPS = 8

    def __init__(self, difficulty: Difficulty = "random", pinned_scenario_id: Optional[str] = None):
        """Create the environment and immediately start a first episode.

        Args:
            difficulty: Scenario pool to sample from; any value other than
                "easy", "medium" or "hard" samples from all scenarios.
            pinned_scenario_id: If it matches a known scenario id, that
                scenario is used for every episode instead of sampling.
        """
        # NOTE(review): the base Environment initializer is not invoked here —
        # confirm whether openenv expects super().__init__() to be called.
        self.difficulty = difficulty
        self._pinned_scenario_id = pinned_scenario_id
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._scenario = None
        self._actions_taken: List[str] = []
        self._investigation_done = False
        self._deep_investigation_done = False
        self._cumulative_score = 0.0
        self._done = False
        self.reset()

    def reset(self) -> SOCObservation:
        """Start a fresh episode and return its opening observation.

        Clears all per-episode bookkeeping, assigns a new episode id, picks a
        scenario and reports the initial alert with an empty context.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._actions_taken = []
        self._investigation_done = False
        self._deep_investigation_done = False
        self._cumulative_score = 0.0
        self._done = False
        self._scenario = self._pick_scenario()

        scenario = self._scenario
        return SOCObservation(
            alert_type=scenario["alert_type"],
            severity=scenario["severity"],
            signals=scenario["initial_signals"],
            context={},
            available_actions=self._get_available_actions(),
            phase="detection",
            feedback=(
                f"New Alert: {scenario['alert_type'].replace('_', ' ').title()}\n"
                f"Severity: {scenario['severity'].upper()}\n"
                f"Description: {scenario['description']}\n"
                "Tip: Use investigate, query_logs, or check_threat_intel to gather context."
            ),
            score=0.0,
            step=0,
            max_steps=self._max_steps(),
            done=False,
            reward=0.0,
        )

    def step(self, action: SOCAction) -> SOCObservation:
        """Apply one agent decision and return the resulting observation.

        The decision is scored by ``_evaluate``, added to the cumulative
        score and recorded. The episode ends when the decision is a terminal
        action, or when the step budget is exhausted (which costs an extra
        0.2 penalty). Stepping a finished episode returns a terminal
        observation with zero reward.
        """
        if self._scenario is None:
            self.reset()
        if self._done:
            return self._terminal_obs("Episode already ended. Call reset().")

        self._state.step_count += 1
        decision = action.decision
        max_steps = self._max_steps()

        # Score first: _evaluate must see the investigation flags as they were
        # *before* this step; _build_context updates them afterwards.
        reward, feedback, phase = self._evaluate(decision)
        self._cumulative_score += reward
        self._actions_taken.append(decision)

        done = False
        if decision in TERMINAL_ACTIONS:
            done = True
        elif self._state.step_count >= max_steps:
            done = True
            reward -= 0.2
            self._cumulative_score -= 0.2
            feedback += f" Max steps ({max_steps}) reached — incident unresolved."
        self._done = done

        # Build context based on investigation depth
        context = self._build_context(decision)

        return SOCObservation(
            alert_type=self._scenario["alert_type"],
            severity=self._scenario["severity"],
            signals=self._scenario["initial_signals"],
            context=context,
            available_actions=[] if done else self._get_available_actions(),
            phase=phase,
            feedback=feedback,
            score=round(self._cumulative_score, 2),
            step=self._state.step_count,
            max_steps=max_steps,
            done=done,
            reward=round(reward, 2),
        )

    def _build_context(self, decision: str) -> dict:
        """Reveal context progressively based on investigation depth.

        - First investigate/query_logs: reveals basic investigation_context
        - check_threat_intel: reveals threat_intel_context if available
        - run_sandbox: reveals sandbox_context if available
        - investigate once basic context is out: reveals
          deep_investigation_context if available

        Scenario context dicts are shared module-level data, so every branch
        annotates a shallow copy rather than the scenario's own dict;
        otherwise "_hint"/"_source" keys would leak into later episodes and
        other sessions.

        Returns:
            The context dict for this step (empty when nothing is revealed).
        """
        scenario = self._scenario
        context: dict = {}

        if decision == "investigate" and not self._investigation_done:
            self._investigation_done = True
            context = dict(scenario.get("investigation_context") or {})
            # Add a note if deeper investigation is possible
            if scenario.get("deep_investigation_context"):
                context["_hint"] = "More context available — try check_threat_intel or run_sandbox."
        elif decision == "query_logs" and not self._investigation_done:
            self._investigation_done = True
            context = dict(scenario.get("investigation_context") or {})
            context["_source"] = "SIEM log query results"
        elif decision == "check_threat_intel":
            threat_intel = scenario.get("threat_intel_context", {})
            if threat_intel:
                context = dict(threat_intel)
                context["_source"] = "Threat intelligence platform"
            elif self._investigation_done:
                context = {"_note": "No additional threat intel beyond what was already found."}
            else:
                context = {"_note": "Run investigate first to correlate threat intel."}
        elif decision == "run_sandbox":
            sandbox = scenario.get("sandbox_context", {})
            if sandbox:
                # A successful sandbox run counts as the deep investigation.
                self._deep_investigation_done = True
                context = dict(sandbox)
                context["_source"] = "Dynamic sandbox analysis"
            else:
                context = {"_note": "No samples available for sandbox analysis."}
        elif decision == "investigate" and self._investigation_done:
            # Second investigate reveals deeper context
            deep = scenario.get("deep_investigation_context", {})
            if deep and not self._deep_investigation_done:
                self._deep_investigation_done = True
                context = dict(deep)
                context["_source"] = "Deep-dive investigation"
            else:
                context = {"_note": "No additional context found. Consider other actions."}

        return context

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state

    def _max_steps(self) -> int:
        """Return the step budget for the active scenario's difficulty."""
        if not self._scenario:
            return self._DEFAULT_MAX_STEPS
        return MAX_STEPS.get(self._scenario["difficulty"], self._DEFAULT_MAX_STEPS)

    def _pick_scenario(self):
        """Choose the scenario for the next episode.

        A pinned scenario id wins when it resolves to a known scenario; an
        unknown pinned id silently falls back to random sampling from the
        pool selected by ``self.difficulty``.
        """
        if self._pinned_scenario_id:
            pinned = SCENARIO_BY_ID.get(self._pinned_scenario_id)
            if pinned:
                return pinned

        if self.difficulty == "easy":
            pool = EASY_SCENARIOS
        elif self.difficulty == "medium":
            pool = MEDIUM_SCENARIOS
        elif self.difficulty == "hard":
            pool = HARD_SCENARIOS
        else:
            pool = SCENARIOS
        return random.choice(pool)

    def _get_available_actions(self) -> List[str]:
        """Return a fresh list of every decision the agent may submit."""
        return [
            "ignore", "monitor", "investigate", "query_logs",
            "check_threat_intel", "run_sandbox", "block_ip",
            "block_account", "isolate_device", "escalate",
            "request_mfa", "patch_system", "collect_forensics",
        ]

    def _evaluate(self, decision: str):
        """Score a decision against the active scenario.

        Reads, but does not modify, the episode's investigation flags and
        action history.

        Returns:
            A ``(reward, feedback, phase)`` tuple: the float reward for this
            step, a human-readable explanation, and the phase label to report.
        """
        scenario = self._scenario
        is_fp = scenario["false_positive"]
        correct_seq = scenario["correct_sequence"]
        optimal = scenario["optimal_terminal"]

        # A second "investigate" is the deep dive that _build_context reveals,
        # so it must not be treated as a wasted repeat while that deep context
        # is still pending.
        deep_dive_pending = (
            decision == "investigate"
            and self._investigation_done
            and not self._deep_investigation_done
            and bool(scenario.get("deep_investigation_context"))
        )
        if decision in self._actions_taken and not deep_dive_pending:
            return -0.1, f"Already chose '{decision}'. Try a different approach.", "investigation"

        # Investigation actions — always somewhat useful
        if decision in INVESTIGATION_ACTIONS:
            if decision == "investigate":
                if not self._investigation_done:
                    return 0.15, "Investigation initiated. Basic context now available.", "investigation"
                if not self._deep_investigation_done and scenario.get("deep_investigation_context"):
                    return 0.10, "Deeper investigation complete. Additional context revealed.", "investigation"
                return 0.05, "No new findings from further investigation.", "investigation"
            if decision == "query_logs":
                if not self._investigation_done:
                    return 0.15, "SIEM log query complete. Context now available.", "investigation"
                return 0.05, "Logs already queried. Try correlating with threat intel.", "investigation"
            if decision == "check_threat_intel":
                if scenario.get("threat_intel_context"):
                    return 0.12, "Threat intel matched. IOCs and attribution context revealed.", "investigation"
                return 0.05, "No threat intel match found for these indicators.", "investigation"
            if decision == "run_sandbox":
                if scenario.get("sandbox_context"):
                    return 0.12, "Sandbox detonation complete. Malware behavior confirmed.", "investigation"
                return 0.05, "Nothing to sandbox — no file samples available.", "investigation"

        # False positive handling
        if is_fp:
            if decision == "ignore":
                return 0.8, "Correct! This was a false positive — alert closed.", "closed"
            if decision in TERMINAL_ACTIONS:
                return -0.3, "Over-reaction! This was a false positive — legitimate activity disrupted.", "closed"
            return 0.0, f"'{decision}' noted but has no effect on a false positive.", "monitoring"

        # Real threat handling
        if decision == "ignore":
            return -0.5, "Dangerous! This is a real threat — ignoring it is a critical mistake.", "detection"

        if decision == optimal and decision in TERMINAL_ACTIONS:
            return 1.0, f"Perfect! '{decision}' is exactly the right call. Incident contained.", "resolved"

        if decision in correct_seq:
            idx = correct_seq.index(decision)
            # Small bonus only for the first action of the correct sequence.
            seq_bonus = 0.05 if idx == 0 else 0.0
            return 0.3 + seq_bonus, f"Good step! Part of correct response sequence ({idx+1}/{len(correct_seq)}).", "containment"

        if decision == "escalate" and scenario["severity"] in ("low", "medium"):
            return -0.2, "Premature escalation on low/medium severity — handle at Tier-1 first.", "investigation"

        if decision in TERMINAL_ACTIONS:
            return -0.3, f"Wrong terminal action. Optimal response was: '{optimal}'.", "closed"

        if decision == "monitor":
            if scenario["severity"] in ("critical", "high"):
                return -0.1, "Passive monitoring on a high/critical severity alert wastes time.", "monitoring"
            return 0.05, "Monitoring in progress — gather more context before acting.", "monitoring"

        return 0.0, f"'{decision}' noted. No significant effect on this incident.", "investigation"

    def _terminal_obs(self, msg: str) -> SOCObservation:
        """Build the zero-reward observation returned once an episode is over.

        Args:
            msg: Feedback text explaining why no further step was applied.
        """
        return SOCObservation(
            alert_type=self._scenario["alert_type"] if self._scenario else "",
            severity="",
            signals=[],
            context={},
            available_actions=[],
            phase="closed",
            feedback=msg,
            score=round(self._cumulative_score, 2),
            step=self._state.step_count,
            max_steps=self._max_steps(),
            done=True,
            reward=0.0,
        )