# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""ProcureRL Environment Client."""

from typing import Dict, Any

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import NegotiationAction, NegotiationObservation, NegotiationState


class ProcureRLEnv(
    EnvClient[NegotiationAction, NegotiationObservation, NegotiationState]
):
    """
    Client for the ProcureRL Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server.

    Example:
        >>> with ProcureRLEnv(base_url="http://localhost:7860") as client:
        ...     result = client.reset(task_id="single_issue")
        ...     print(result.observation.supplier_message)
        ...
        ...     action = NegotiationAction(
        ...         move_type="make_offer",
        ...         terms={"price": 42000},
        ...         message="Let's discuss",
        ...     )
        ...     result = client.step(action)
        ...     print(result.observation.supplier_message)
    """

    def _step_payload(self, action: NegotiationAction) -> Dict[str, Any]:
        """Convert a negotiation action into the dict sent for a step.

        Args:
            action: The buyer's move for this round.

        Returns:
            A dict with the keys ``move_type``, ``terms`` and ``message``
            copied from ``action``.
        """
        return {
            "move_type": action.move_type,
            "terms": action.terms,
            "message": action.message,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[NegotiationObservation]:
        """Build a ``StepResult`` from a response payload.

        Observation fields missing from the payload fall back to neutral
        defaults (empty string / dict / list, ``0``, ``"neutral"`` for
        ``rapport_hint``).

        Args:
            payload: Decoded response containing an ``observation`` dict plus
                top-level ``reward`` and ``done`` entries.

        Returns:
            The parsed observation wrapped in a ``StepResult`` together with
            the top-level reward and done flag.
        """
        # "observation": null would otherwise break the .get() calls below,
        # so it is treated the same as a missing observation.
        obs_data = payload.get("observation") or {}

        # The top-level flag is what StepResult reports. It is also the
        # fallback for the observation, so the two agree when "done" is not
        # repeated inside the nested observation dict.
        done = payload.get("done", False)

        observation = NegotiationObservation(
            task_id=obs_data.get("task_id", ""),
            round_number=obs_data.get("round_number", 0),
            max_rounds=obs_data.get("max_rounds", 0),
            supplier_message=obs_data.get("supplier_message", ""),
            current_offer=obs_data.get("current_offer", {}),
            last_4_exchanges=obs_data.get("last_4_exchanges", []),
            buyer_constraints=obs_data.get("buyer_constraints", {}),
            rapport_hint=obs_data.get("rapport_hint", "neutral"),
            done=obs_data.get("done", done),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> NegotiationState:
        """Build a ``NegotiationState`` from a state payload.

        Args:
            payload: Decoded state response. Missing keys fall back to the
                defaults used below; ``final_terms`` defaults to ``None``.

        Returns:
            The parsed negotiation state.
        """
        return NegotiationState(
            task_id=payload.get("task_id", ""),
            episode_id=payload.get("episode_id", ""),
            round_number=payload.get("round_number", 0),
            rapport_score=payload.get("rapport_score", 0.5),
            consecutive_concessions=payload.get("consecutive_concessions", 0),
            deal_reached=payload.get("deal_reached", False),
            final_terms=payload.get("final_terms"),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )