Demo / client.py
Ajayyy00
Initial commit of CyberSOC upgraded RLVR environment
57e71f8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""CyberSOCEnv Client — connects to the SOC environment server."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from .models import (
SOCObservation,
SOCActionWrapper,
SOCState,
Alert,
Severity,
ThreatType,
NetworkTopology,
ForensicsResult,
TimelineEntry,
)
class CyberSOCClient(
    EnvClient[SOCActionWrapper, SOCObservation, SOCState]
):
    """
    Client for the CyberSOCEnv environment.

    Connects via WebSocket to the SOC environment server for
    low-latency, persistent-session interaction.

    This class only supplies the (de)serialization hooks; it converts
    outgoing actions to JSON payloads and incoming payloads to the typed
    ``SOCObservation`` / ``SOCState`` models.

    Example:
        >>> with CyberSOCClient(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.alert_queue)
        ...
        ...     action = SOCActionWrapper(type="query_host", hostname="WS-001")
        ...     result = client.step(action)
        ...     print(result.observation.host_forensics)
    """

    def _step_payload(self, action: SOCActionWrapper) -> Dict:
        """Convert a SOCActionWrapper to the JSON payload for a step message.

        Fields whose value is ``None`` are omitted from the payload.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[SOCObservation]:
        """Parse a server response into ``StepResult[SOCObservation]``.

        Args:
            payload: Decoded response dict. ``reward`` and ``done`` are read
                from the top level; everything else is read from the nested
                ``observation`` dict.

        Returns:
            A ``StepResult`` wrapping the parsed observation. The episode id
            is additionally attached to the result as ``episode_id``.

        Missing keys fall back to defaults. An ``observation``,
        ``alert_queue``, ``network_topology``, ``host_forensics`` or
        ``timeline`` that is present but ``null`` is treated the same as a
        missing one instead of raising.
        """
        # `or {}` (not a .get default) so an explicit JSON null is tolerated.
        obs_data = payload.get("observation") or {}

        # Read once; both the observation and the StepResult carry these.
        reward = payload.get("reward")
        done = payload.get("done", False)

        # Parse alerts
        alerts = [Alert(**a) for a in obs_data.get("alert_queue") or []]

        # Parse network topology (empty/missing -> default topology)
        topo_data = obs_data.get("network_topology")
        topology = NetworkTopology(**topo_data) if topo_data else NetworkTopology()

        # Parse forensics (may be None)
        forensics_data = obs_data.get("host_forensics")
        forensics = ForensicsResult(**forensics_data) if forensics_data else None

        # Parse timeline
        timeline = [TimelineEntry(**t) for t in obs_data.get("timeline") or []]

        observation = SOCObservation(
            episode_id=obs_data.get("episode_id", ""),
            alert_queue=alerts,
            network_topology=topology,
            host_forensics=forensics,
            timeline=timeline,
            business_impact_score=obs_data.get("business_impact_score", 0.0),
            step_count=obs_data.get("step_count", 0),
            active_threats=obs_data.get("active_threats", []),
            max_steps=obs_data.get("max_steps", 30),
            task_id=obs_data.get("task_id", "easy"),
            total_reward=obs_data.get("total_reward", 0.0),
            final_score=obs_data.get("final_score"),
            grade_breakdown=obs_data.get("grade_breakdown"),
            done=done,
            reward=reward,
        )

        result = StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )
        # Attach episode_id directly on the result for easy RL loop access
        result.episode_id = observation.episode_id  # type: ignore[attr-defined]
        return result

    def _parse_state(self, payload: Dict) -> SOCState:
        """Parse a server response into ``SOCState``.

        Args:
            payload: Decoded state dict. Missing keys fall back to defaults;
                a missing ``episode_id`` is passed through as ``None``.

        Returns:
            The populated ``SOCState``.
        """
        return SOCState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", "easy"),
            total_reward=payload.get("total_reward", 0.0),
            business_impact=payload.get("business_impact", 0.0),
        )