Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""ProcureRL Environment Client."""
| from typing import Dict, Any | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from .models import NegotiationAction, NegotiationObservation, NegotiationState | |
class ProcureRLEnv(
    EnvClient[NegotiationAction, NegotiationObservation, NegotiationState]
):
    """
    Client for the ProcureRL Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> with ProcureRLEnv(base_url="http://localhost:7860") as client:
        ...     result = client.reset(task_id="single_issue")
        ...     print(result.observation.supplier_message)
        ...
        ...     action = NegotiationAction(move_type="make_offer", terms={"price": 42000}, message="Let's discuss")
        ...     result = client.step(action)
        ...     print(result.observation.supplier_message)
    """

    def _step_payload(self, action: NegotiationAction) -> Dict[str, Any]:
        """Serialize a negotiation action into the request body for a step.

        Args:
            action: The negotiation move to send to the server.

        Returns:
            A dict holding the action's ``move_type``, ``terms`` and
            ``message`` fields, unchanged.
        """
        return {
            "move_type": action.move_type,
            "terms": action.terms,
            "message": action.message,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[NegotiationObservation]:
        """Convert a decoded server response into a ``StepResult``.

        Args:
            payload: Decoded response. Observation fields are read from
                ``payload["observation"]``; ``reward`` and ``done`` are read
                from the top level of ``payload``.

        Returns:
            A ``StepResult`` wrapping a ``NegotiationObservation``. Missing
            fields fall back to the defaults below. The observation's ``done``
            flag always matches the result's ``done`` flag.
        """
        # "observation" may be absent or explicitly null; treat both as empty
        # rather than crashing on ``None.get``.
        obs_data = payload.get("observation") or {}

        # Resolve the termination flag once so observation.done and
        # StepResult.done cannot disagree. Prefer the top-level flag and fall
        # back to an observation-level flag only when the top level omits it.
        done = payload.get("done", obs_data.get("done", False))

        observation = NegotiationObservation(
            task_id=obs_data.get("task_id", ""),
            round_number=obs_data.get("round_number", 0),
            max_rounds=obs_data.get("max_rounds", 0),
            supplier_message=obs_data.get("supplier_message", ""),
            current_offer=obs_data.get("current_offer", {}),
            last_4_exchanges=obs_data.get("last_4_exchanges", []),
            buyer_constraints=obs_data.get("buyer_constraints", {}),
            rapport_hint=obs_data.get("rapport_hint", "neutral"),
            done=done,
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> NegotiationState:
        """Build a ``NegotiationState`` from a decoded state response.

        Args:
            payload: Flat dict of state fields. Any missing key falls back to
                the default used below; ``final_terms`` defaults to ``None``.

        Returns:
            The populated ``NegotiationState``.
        """
        return NegotiationState(
            task_id=payload.get("task_id", ""),
            episode_id=payload.get("episode_id", ""),
            round_number=payload.get("round_number", 0),
            rapport_score=payload.get("rapport_score", 0.5),
            consecutive_concessions=payload.get("consecutive_concessions", 0),
            deal_reached=payload.get("deal_reached", False),
            final_terms=payload.get("final_terms"),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )