| |
| |
| |
| |
| |
|
|
| """CyberSOCEnv Client — connects to the SOC environment server.""" |
|
|
| from typing import Dict |
|
|
| from openenv.core import EnvClient |
| from openenv.core.client_types import StepResult |
|
|
| from .models import ( |
| SOCObservation, |
| SOCActionWrapper, |
| SOCState, |
| Alert, |
| Severity, |
| ThreatType, |
| NetworkTopology, |
| ForensicsResult, |
| TimelineEntry, |
| ) |
|
|
|
|
class CyberSOCClient(
    EnvClient[SOCActionWrapper, SOCObservation, SOCState]
):
    """
    Client for the CyberSOCEnv environment.

    Connects via WebSocket to the SOC environment server for
    low-latency, persistent-session interaction.

    Example:
        >>> with CyberSOCClient(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.alert_queue)
        ...
        ...     from play.models import SOCActionWrapper
        ...     result = client.step(SOCActionWrapper(type="query_host", hostname="WS-001"))
        ...     print(result.observation.host_forensics)
    """

    def _step_payload(self, action: SOCActionWrapper) -> Dict:
        """Serialize an action into the JSON payload of a step message.

        Fields whose value is ``None`` are omitted from the payload.

        Args:
            action: The action wrapper to send to the server.

        Returns:
            A plain dict produced by ``action.model_dump(exclude_none=True)``.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[SOCObservation]:
        """Parse a server response into ``StepResult[SOCObservation]``.

        Args:
            payload: Decoded server message. Reads the ``observation`` dict
                plus the top-level ``reward`` and ``done`` fields.

        Returns:
            A ``StepResult`` wrapping the rebuilt ``SOCObservation``. The
            observation's ``episode_id`` is also attached to the result.
        """
        # ``or {}`` / ``or []`` also cover keys the server sends as JSON null;
        # ``dict.get(key, default)`` would hand those through as ``None``.
        obs_data = payload.get("observation") or {}

        # Read these once so the observation and the StepResult agree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        alerts = [Alert(**a) for a in (obs_data.get("alert_queue") or [])]

        # A missing, null or empty topology yields a default-constructed one.
        topology = NetworkTopology(**(obs_data.get("network_topology") or {}))

        # Missing, null or empty forensics are represented as None.
        forensics_data = obs_data.get("host_forensics")
        forensics = ForensicsResult(**forensics_data) if forensics_data else None

        timeline = [TimelineEntry(**t) for t in (obs_data.get("timeline") or [])]

        observation = SOCObservation(
            episode_id=obs_data.get("episode_id", ""),
            alert_queue=alerts,
            network_topology=topology,
            host_forensics=forensics,
            timeline=timeline,
            business_impact_score=obs_data.get("business_impact_score", 0.0),
            step_count=obs_data.get("step_count", 0),
            active_threats=obs_data.get("active_threats") or [],
            max_steps=obs_data.get("max_steps", 30),
            task_id=obs_data.get("task_id", "easy"),
            total_reward=obs_data.get("total_reward", 0.0),
            final_score=obs_data.get("final_score"),
            grade_breakdown=obs_data.get("grade_breakdown"),
            done=done,
            reward=reward,
        )

        result = StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )
        # Extra attribute set on the result so callers can read the episode id
        # without going through the observation.
        result.episode_id = observation.episode_id
        return result

    def _parse_state(self, payload: Dict) -> SOCState:
        """Parse a server state response into ``SOCState``.

        Args:
            payload: Decoded state message with top-level ``episode_id``,
                ``step_count``, ``task_id``, ``total_reward`` and
                ``business_impact`` fields. Any of them may be absent.

        Returns:
            A ``SOCState``. Absent fields fall back to the defaults below
            (``episode_id`` falls back to ``None``).
        """
        # NOTE(review): the state reads "business_impact" while observations
        # use "business_impact_score" -- confirm this matches the server.
        return SOCState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", "easy"),
            total_reward=payload.get("total_reward", 0.0),
            business_impact=payload.get("business_impact", 0.0),
        )
|
|