File size: 3,012 Bytes
c1be7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cf4a9f
c1be7c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""ProcureRL Environment Client."""

from typing import Dict, Any

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import NegotiationAction, NegotiationObservation, NegotiationState


class ProcureRLEnv(
    EnvClient[NegotiationAction, NegotiationObservation, NegotiationState]
):
    """
    Client for the ProcureRL Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> with ProcureRLEnv(base_url="http://localhost:7860") as client:
        ...     result = client.reset(task_id="single_issue")
        ...     print(result.observation.supplier_message)
        ...
        ...     action = NegotiationAction(move_type="make_offer", terms={"price": 42000}, message="Let's discuss")
        ...     result = client.step(action)
        ...     print(result.observation.supplier_message)
    """

    def _step_payload(self, action: NegotiationAction) -> Dict[str, Any]:
        """Serialize a negotiation action into the request body for a step.

        Args:
            action: The buyer's move to send to the environment server.

        Returns:
            A dict carrying the action's ``move_type``, ``terms`` and
            ``message`` fields unchanged.
        """
        return {
            "move_type": action.move_type,
            "terms": action.terms,
            "message": action.message,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[NegotiationObservation]:
        """Convert a server response payload into a typed ``StepResult``.

        Missing fields fall back to neutral defaults (empty strings, zero
        counters, empty containers, ``"neutral"`` rapport hint).

        Args:
            payload: Decoded JSON response. The observation fields are read
                from ``payload["observation"]``; ``reward`` and ``done`` are
                read from the top level.

        Returns:
            A ``StepResult`` wrapping a ``NegotiationObservation``. Its
            ``done`` flag always equals ``observation.done``.
        """
        # Treat a missing or explicit-null "observation" entry as empty rather
        # than crashing on ``None.get``.
        obs_data = payload.get("observation") or {}

        # The termination flag may appear at the top level, inside the
        # observation, or both. Resolve it once, preferring the top-level
        # value, so StepResult.done and observation.done can never disagree.
        done = payload.get("done", obs_data.get("done", False))

        observation = NegotiationObservation(
            task_id=obs_data.get("task_id", ""),
            round_number=obs_data.get("round_number", 0),
            max_rounds=obs_data.get("max_rounds", 0),
            supplier_message=obs_data.get("supplier_message", ""),
            current_offer=obs_data.get("current_offer", {}),
            last_4_exchanges=obs_data.get("last_4_exchanges", []),
            buyer_constraints=obs_data.get("buyer_constraints", {}),
            rapport_hint=obs_data.get("rapport_hint", "neutral"),
            done=done,
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> NegotiationState:
        """Convert a server state payload into a ``NegotiationState``.

        Missing fields fall back to defaults: empty ids, zero counters,
        ``rapport_score`` of 0.5, no deal, ``final_terms`` of ``None`` and a
        cumulative reward of 0.0.

        Args:
            payload: Decoded JSON state dict with flat (non-nested) keys.

        Returns:
            The populated ``NegotiationState``.
        """
        return NegotiationState(
            task_id=payload.get("task_id", ""),
            episode_id=payload.get("episode_id", ""),
            round_number=payload.get("round_number", 0),
            rapport_score=payload.get("rapport_score", 0.5),
            consecutive_concessions=payload.get("consecutive_concessions", 0),
            deal_reached=payload.get("deal_reached", False),
            final_terms=payload.get("final_terms"),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )