File size: 3,438 Bytes
c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb 97fbf33 c0db7bb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""AML Investigator Environment Client.
High-level WebSocket client that wraps the OpenEnv EnvClient base class
with AML-specific action/observation types.
"""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import AmlAction, AmlObservation
class AmlEnv(EnvClient[AmlAction, AmlObservation, State]):
    """
    WebSocket client for the AML Investigator environment.

    Maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step investigations with lower per-step latency.

    Example (Docker):
        >>> client = AmlEnv.from_docker_image("aml-env:latest")
        >>> try:
        ...     obs = client.reset(task="aml_easy")
        ...     result = client.step(AmlAction(action={
        ...         "action_type": "query_transactions",
        ...         "account_id": "ACC-9001",
        ...     }))
        ...     print(result.observation.last_action_result)
        ... finally:
        ...     client.close()

    Example (existing server):
        >>> with AmlEnv(base_url="http://localhost:7860") as env:
        ...     obs = env.reset(task="aml_easy")
        ...     result = env.step(AmlAction(action={
        ...         "action_type": "submit_decision",
        ...         "decision": "CLEAR",
        ...         "evidence_links": [],
        ...     }))
    """

    def _step_payload(self, action: AmlAction) -> Dict:
        """
        Serialize an AmlAction into the JSON dict sent over the WebSocket.

        Args:
            action: Typed AmlAction wrapper containing the specific tool call.

        Returns:
            ``action.model_dump()`` -- a dict holding the nested ``action``
            key the server expects (plus any other fields the model defines).
        """
        return action.model_dump()

    def _parse_result(self, payload: Dict) -> StepResult[AmlObservation]:
        """
        Deserialize the server's JSON response into a typed StepResult.

        ``reward`` and ``done`` are read from the top level of ``payload``
        (not from the nested observation) and are copied onto both the
        AmlObservation and the returned StepResult.

        Args:
            payload: Raw JSON response dict from the server.

        Returns:
            StepResult containing an AmlObservation. Missing observation
            fields fall back to defaults ("" / 0 / None). A missing reward
            falls back to 0.0, and a missing done flag to False.
        """
        # ``payload.get("observation", {})`` returns None when the server sends
        # an explicit ``"observation": null``, and the ``.get`` calls below
        # would then raise AttributeError. Treat null the same as absent.
        obs_data = payload.get("observation") or {}

        # Look the episode-level fields up once so the observation and the
        # StepResult are guaranteed to carry identical values.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = AmlObservation(
            alert_details=obs_data.get("alert_details", ""),
            budget_remaining=obs_data.get("budget_remaining", 0),
            last_action=obs_data.get("last_action"),
            last_action_result=obs_data.get("last_action_result"),
            error_message=obs_data.get("error_message"),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Deserialize the server's state response into a State object.

        Only ``episode_id`` and ``step_count`` are read from ``payload``;
        any other keys the server sends are ignored here.

        Args:
            payload: Raw JSON response dict from the server.

        Returns:
            State with ``episode_id`` (None if absent) and ``step_count``
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|