# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Constraint Env Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import ConstraintAction, ConstraintObservation, ConstraintState


class ConstraintEnv(
    EnvClient[ConstraintAction, ConstraintObservation, ConstraintState]
):
    """
    Client for the Constraint Env Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.
    """

    # Fallback reward when a step response carries no "reward" key.
    # NOTE(review): why 0.01 rather than 0.0 is not visible here — confirm.
    _DEFAULT_REWARD = 0.01
    # Fallback episode length when a state response carries no "max_steps" key.
    _DEFAULT_MAX_STEPS = 5

    @staticmethod
    def _require_json_payload(payload) -> None:
        """
        Reject a server response that arrived as a raw string instead of JSON.

        Raises:
            ValueError: If ``payload`` is a ``str``.
        """
        if isinstance(payload, str):
            raise ValueError(f"Server returned an error string instead of JSON: {payload}")

    def _step_payload(self, action: ConstraintAction) -> Dict:
        """
        Convert a ConstraintAction into the JSON payload for a step message.

        Only ``action.ast_output`` is sent to the server.
        """
        return {
            "ast_output": action.ast_output,
        }

    def _parse_result(self, payload: Dict) -> StepResult[ConstraintObservation]:
        """
        Parse a server step response into StepResult[ConstraintObservation].

        Args:
            payload: Decoded server response, read for the optional keys
                "observation", "reward" and "done".

        Returns:
            A StepResult whose observation carries the same reward/done values
            as the StepResult itself.

        Raises:
            ValueError: If the server returned a plain string instead of JSON.
        """
        self._require_json_payload(payload)

        obs_data = payload.get("observation")
        # A missing, null, string, or otherwise non-dict observation is treated
        # as empty, so the field lookups below cannot fail on a non-dict.
        if not isinstance(obs_data, dict):
            obs_data = {}

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward", self._DEFAULT_REWARD)
        done = payload.get("done", False)

        observation = ConstraintObservation(
            prompt=obs_data.get("prompt", ""),
            info=obs_data.get("info", {}),
            done=done,
            reward=reward,
            messages=obs_data.get("messages", []),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> ConstraintState:
        """
        Parse a server state response into a ConstraintState.

        Missing fields fall back to ``episode_id=None``, ``step_count=0`` and
        ``max_steps=5``.

        Raises:
            ValueError: If the server returned a plain string instead of JSON.
        """
        self._require_json_payload(payload)
        return ConstraintState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            max_steps=payload.get("max_steps", self._DEFAULT_MAX_STEPS),
        )