# security-audit-env / client.py
# Uploaded by anshumanatrey via huggingface_hub (commit 2b85191, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Security Audit Environment Client."""
from typing import Any, Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from .models import SecurityAuditAction, SecurityAuditObservation, SecurityAuditState
class SecurityAuditEnv(
    EnvClient[SecurityAuditAction, SecurityAuditObservation, SecurityAuditState]
):
    """
    Typed client for the Security Audit Environment.

    Converts ``SecurityAuditAction`` objects into request payloads and turns
    response payloads back into ``SecurityAuditObservation`` /
    ``SecurityAuditState`` objects.

    Example:
        >>> with SecurityAuditEnv(base_url="http://localhost:8000").sync() as env:
        ...     result = env.reset(scenario_id="easy")
        ...     print(result.observation.message)
        ...
        ...     result = env.step(SecurityAuditAction(
        ...         action_type="list_tools"
        ...     ))
        ...     print(result.observation.tool_output)
    """

    def _step_payload(self, action: SecurityAuditAction) -> Dict[str, Any]:
        """Serialize *action* into the dict sent for a step.

        Fields whose value is ``None`` are left out of the payload
        (``exclude_none=True``).
        """
        body = action.model_dump(exclude_none=True)
        return body

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[SecurityAuditObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Observation fields are read from ``payload["observation"]``.
        ``done`` and ``reward`` come from the top level of *payload* and are
        copied onto both the observation and the returned ``StepResult``.
        Any missing key falls back to an empty/zero default. The exception is
        ``available_tools`` and ``reward``, which default to ``None``.
        """
        raw = payload.get("observation", {})
        # done/reward live at the top level, not inside "observation".
        is_done = payload.get("done", False)
        step_reward = payload.get("reward")

        fields: Dict[str, Any] = {
            "tool_output": raw.get("tool_output", ""),
            "available_tools": raw.get("available_tools"),
            "discovered_hosts": raw.get("discovered_hosts", []),
            "discovered_services": raw.get("discovered_services", {}),
            "findings_submitted": raw.get("findings_submitted", 0),
            "steps_remaining": raw.get("steps_remaining", 0),
            "message": raw.get("message", ""),
            "metadata": raw.get("metadata", {}),
        }
        observation = SecurityAuditObservation(
            **fields, done=is_done, reward=step_reward
        )
        return StepResult(observation=observation, reward=step_reward, done=is_done)

    def _parse_state(self, payload: Dict[str, Any]) -> SecurityAuditState:
        """Convert a state payload dict into a ``SecurityAuditState``.

        Missing keys fall back to defaults. ``episode_id`` defaults to
        ``None``, ``max_steps`` to 50, ``total_reward`` to 0.0, and the
        remaining fields to empty strings/lists/dicts or 0.
        """
        get = payload.get
        return SecurityAuditState(
            episode_id=get("episode_id"),
            step_count=get("step_count", 0),
            scenario_id=get("scenario_id", ""),
            scenario_name=get("scenario_name", ""),
            target_network=get("target_network", ""),
            max_steps=get("max_steps", 50),
            discovered_hosts=get("discovered_hosts", []),
            discovered_ports=get("discovered_ports", {}),
            discovered_services=get("discovered_services", {}),
            submitted_findings=get("submitted_findings", []),
            total_reward=get("total_reward", 0.0),
        )