# procure-rl / server / Procure_RL_environment.py
# Uploaded by akshaypulla using huggingface_hub (commit c1be7c3, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
ProcureRL Environment Implementation.
An OpenEnv-compliant RL environment for procurement negotiation where
an LLM agent learns to negotiate against scripted supplier opponents.
"""
import uuid
import zlib
from typing import Optional, Dict, Any

try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    Environment = object

try:
    from ..models import NegotiationAction, NegotiationObservation, NegotiationState
    from ..opponent import ScriptedPersonaOpponent
    from ..graders import grade
except ImportError:
    import sys
    import os

    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from models import NegotiationAction, NegotiationObservation, NegotiationState
    from opponent import ScriptedPersonaOpponent
    from graders import grade
# Per-task settings, keyed by the ``task_id`` passed to ``reset()``.
# Each entry provides:
#   persona           - handed to ScriptedPersonaOpponent when an episode starts
#   max_rounds        - round limit after which the episode ends
#   buyer_constraints - per-issue figures surfaced to the agent in observations
TASK_CONFIG = {
    # One issue only: price.
    "single_issue": {
        "persona": "cooperative",
        "max_rounds": 6,
        "buyer_constraints": {
            "price": {"target": 36000, "worst": 55000, "budget": 53000},
        },
    },
    # Two issues: price and payment terms.
    "multi_issue": {
        "persona": "cash_flow_stressed",
        "max_rounds": 8,
        "buyer_constraints": {
            "price": {"target": 40000, "worst": 58000, "budget": 55000},
            "payment_days": {"target": 60, "worst": 30, "preference": 60},
        },
    },
    # Three issues: price, payment terms and support hours.
    "adversarial": {
        "persona": "aggressive_anchor",
        "max_rounds": 10,
        "buyer_constraints": {
            "price": {"target": 80000, "worst": 120000, "budget": 115000},
            "payment_days": {"target": 60, "worst": 30, "preference": 60},
            "support_hours": {"target": 150, "worst": 80, "preference": 150},
        },
    },
}
# Move types that step() acts on; any other value is answered with an
# "invalid_move_type" error observation.
VALID_MOVES = (
    "make_offer",
    "accept",
    "reject",
    "bundle",
)
class ProcureRLEnvironment(Environment):
    """Negotiation environment in which an agent bargains with a scripted supplier.

    ``reset()`` selects a task from ``TASK_CONFIG`` and returns the supplier's
    opening offer.  ``step()`` then accepts one move per call (see
    ``VALID_MOVES``).  Intermediate steps carry a reward of ``0.0``; when a
    deal closes the reward is whatever ``grade()`` returns for the agreed
    terms; hitting the round limit without a deal ends the episode with
    ``0.0``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create an environment with no active episode; call ``reset()`` next."""
        # NOTE(review): assumes the base class constructor takes no required
        # arguments (true for ``object``, which is the import fallback).
        super().__init__()
        self._state = NegotiationState()
        self._opponent = None
        self._task_config = None
        self._done = False
        # Most recent supplier offer, with private "_"-prefixed keys removed.
        self._last_offer: Dict[str, Any] = {}
        # Number of successive agent offers whose price went up.
        self._consecutive_concessions = 0
        self._prev_agent_price: Optional[float] = None
        self._opponent_opening_price: Optional[float] = None
        self._exchanges: list = []
        # Extra details about the current step, copied into obs.metadata.
        self._last_info: Dict[str, Any] = {}

    def reset(
        self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs
    ) -> NegotiationObservation:
        """Start a new episode and return the supplier's opening observation.

        Args:
            seed: Seed for the scripted opponent; ``None`` means 42.
            episode_id: Identifier to store in the state; a random
                8-character id is generated when omitted.
            **kwargs: ``task_id`` selects an entry of ``TASK_CONFIG``
                (default ``"single_issue"``).

        Returns:
            The opening observation.  For an unknown ``task_id`` an
            observation with ``done=True`` and
            ``metadata["error"] == "unknown_task:<task_id>"`` is returned
            and the previous episode state is left as it was.
        """
        task_id = kwargs.get("task_id", "single_issue")
        seed = seed if seed is not None else 42
        if task_id not in TASK_CONFIG:
            return self._error_obs(
                f"unknown_task:{task_id}",
                f"Unknown task: {task_id}. Valid: {list(TASK_CONFIG.keys())}",
                done=True,
            )

        config = TASK_CONFIG[task_id]
        self._task_config = config
        self._done = False
        self._consecutive_concessions = 0
        self._prev_agent_price = None
        self._exchanges = []
        self._last_info = {}

        # crc32 rather than hash(): str hashes are salted per process, so
        # hash((seed, task_id)) gave a different opponent seed on every run
        # and the episode was not reproducible from ``seed``.  In Python 3
        # crc32 is always an unsigned 32-bit value.
        opponent_seed = zlib.crc32(f"{seed}:{task_id}".encode("utf-8"))
        self._opponent = ScriptedPersonaOpponent(
            task_id=task_id, seed=opponent_seed, persona=config["persona"]
        )
        opening_msg, opening_terms = self._opponent.get_opening_message()
        self._last_offer = opening_terms
        self._opponent_opening_price = opening_terms.get("price", 52000.0)

        self._state = NegotiationState(
            task_id=task_id,
            episode_id=episode_id or str(uuid.uuid4())[:8],
            round_number=0,
            step_count=0,
            rapport_score=0.5,
            consecutive_concessions=0,
            deal_reached=False,
            final_terms=None,
            cumulative_reward=0.0,
        )
        self._exchanges.append(
            {"role": "supplier", "message": opening_msg, "terms": opening_terms}
        )
        return NegotiationObservation(
            task_id=task_id,
            round_number=0,
            max_rounds=config["max_rounds"],
            supplier_message=opening_msg,
            current_offer=opening_terms,
            last_4_exchanges=self._exchanges[-4:],
            buyer_constraints=config["buyer_constraints"],
            rapport_hint="neutral",
            done=False,
        )

    def step(self, action: NegotiationAction, **kwargs) -> NegotiationObservation:
        """Apply one agent move and return the resulting observation.

        Args:
            action: A ``NegotiationAction``.  A plain dict with the keys
                ``move_type``, ``terms`` and ``message`` is also accepted;
                any other object is treated as an empty ``make_offer``.
            **kwargs: Ignored.

        Returns:
            The next observation.  Problems are reported through
            ``metadata["error"]`` instead of exceptions: ``episode_done``,
            ``not_initialized``, ``invalid_move_type:<type>``,
            ``invalid_price:<value>``, ``max_rounds_reached`` and
            ``rejected_at_limit``.  An invalid move type or price does not
            consume a round.
        """
        self._last_info = {}
        if self._done:
            return self._error_obs(
                "episode_done", "Episode finished. Call reset().", done=True
            )
        if self._task_config is None:
            return self._error_obs(
                "not_initialized",
                "Environment not initialized. Call reset() first.",
                done=True,
            )

        if not isinstance(action, NegotiationAction):
            action_dict = (
                action if isinstance(action, dict) else {"move_type": "make_offer"}
            )
            action = NegotiationAction(
                move_type=action_dict.get("move_type", "make_offer"),
                terms=action_dict.get("terms", {}),
                message=action_dict.get("message", ""),
            )
        if action.move_type not in VALID_MOVES:
            return self._error_obs(f"invalid_move_type:{action.move_type}")

        # Tolerate missing terms and validate the price before the round is
        # counted, so malformed input cannot crash the episode.
        terms = action.terms if action.terms is not None else {}
        agent_price: Optional[float] = None
        if "price" in terms:
            try:
                agent_price = float(terms["price"])
            except (TypeError, ValueError):
                return self._error_obs(f"invalid_price:{terms['price']!r}")

        self._state.round_number += 1
        self._state.step_count += 1
        round_num = self._state.round_number
        max_rounds = self._task_config["max_rounds"]

        # A price above the agent's previous one extends the concession
        # streak; an equal or lower one resets it.  Moves without a price
        # leave the streak untouched.
        if agent_price is not None:
            if self._prev_agent_price is not None:
                if agent_price > self._prev_agent_price:
                    self._consecutive_concessions += 1
                else:
                    self._consecutive_concessions = 0
            self._prev_agent_price = agent_price
        self._state.consecutive_concessions = self._consecutive_concessions

        if action.move_type in ("make_offer", "bundle"):
            return self._handle_offer(action, terms, round_num, max_rounds)
        if action.move_type == "accept":
            # Accepting closes the deal at the supplier's latest offer.
            return self._close_deal(self._last_offer, round_num)
        if action.move_type == "reject":
            if round_num >= max_rounds:
                return self._end_without_deal("rejected_at_limit")
            return self._ongoing_obs()
        # Only reachable if VALID_MOVES gains a move that has no handler.
        return self._ongoing_obs()

    @property
    def state(self) -> NegotiationState:
        """Current episode state."""
        return self._state

    def close(self) -> None:
        """Nothing to release; present to satisfy the environment interface."""
        pass

    def _handle_offer(
        self,
        action: NegotiationAction,
        terms: Dict[str, Any],
        round_num: int,
        max_rounds: int,
    ) -> NegotiationObservation:
        """Forward an offer or bundle to the supplier and process its reply."""
        opponent_msg, opponent_terms = self._opponent.respond(
            agent_message=action.message,
            agent_terms=terms,
            round_number=round_num,
            consecutive_concessions=self._consecutive_concessions,
        )
        # Keys starting with "_" (such as "_accepted") are signals for the
        # environment and are not shown to the agent.
        public_terms = {
            k: v for k, v in opponent_terms.items() if not k.startswith("_")
        }
        # Both sides of the round are logged before any observation is built
        # so that last_4_exchanges always includes the supplier's reply.
        self._exchanges.append(
            {"role": "agent", "message": action.message, "terms": terms}
        )
        self._exchanges.append(
            {"role": "supplier", "message": opponent_msg, "terms": public_terms}
        )

        if opponent_terms.get("_accepted"):
            # The supplier took the agent's terms, so those are the deal.
            return self._close_deal(terms, round_num, opponent_msg)

        self._last_offer = public_terms
        self._state.rapport_score = self._opponent.rapport
        if round_num >= max_rounds:
            return self._end_without_deal("max_rounds_reached", opponent_msg)
        return self._ongoing_obs(opponent_msg)

    def _close_deal(
        self,
        terms: Dict[str, Any],
        round_num: int,
        supplier_message: Optional[str] = None,
    ) -> NegotiationObservation:
        """End the episode with a deal at ``terms`` and return the graded result."""
        self._done = True
        self._state.deal_reached = True
        self._state.final_terms = terms
        reward = grade(
            self._state.task_id,
            terms,
            True,
            round_num,
            opponent_opening=self._opponent_opening_price,
            consecutive_concessions_flag=(self._consecutive_concessions >= 2),
        )
        self._state.cumulative_reward = reward
        # Recorded before the observation is built so it lands in metadata.
        self._last_info["deal_price"] = terms.get("price")
        obs = self._make_obs(supplier_message)
        obs.done = True
        obs.reward = reward
        return obs

    def _end_without_deal(
        self, error: str, supplier_message: Optional[str] = None
    ) -> NegotiationObservation:
        """End the episode with zero reward and ``error`` in the metadata."""
        self._done = True
        self._last_info["error"] = error
        obs = self._make_obs(supplier_message)
        obs.done = True
        obs.reward = 0.0
        return obs

    def _ongoing_obs(
        self, supplier_message: Optional[str] = None
    ) -> NegotiationObservation:
        """Observation for a step that leaves the episode running (reward 0.0)."""
        obs = self._make_obs(supplier_message)
        obs.reward = 0.0
        return obs

    def _error_obs(
        self, error: str, message: Optional[str] = None, done: bool = False
    ) -> NegotiationObservation:
        """Observation reporting ``error`` in its metadata; no reward is set."""
        self._last_info = {"error": error}
        obs = self._make_obs(message)
        if done:
            obs.done = True
        return obs

    def _make_obs(
        self, supplier_message: Optional[str] = None
    ) -> NegotiationObservation:
        """Build an observation from the current environment state.

        The rapport score is reduced to a coarse hint: ``"positive"`` at
        0.65 or above, ``"negative"`` at 0.35 or below, otherwise
        ``"neutral"``.
        """
        rapport = self._state.rapport_score
        if rapport >= 0.65:
            hint = "positive"
        elif rapport <= 0.35:
            hint = "negative"
        else:
            hint = "neutral"
        config = self._task_config
        return NegotiationObservation(
            task_id=self._state.task_id or "",
            round_number=self._state.round_number,
            max_rounds=config["max_rounds"] if config else 0,
            supplier_message=supplier_message or "",
            current_offer=self._last_offer,
            last_4_exchanges=self._exchanges[-4:],
            buyer_constraints=config["buyer_constraints"] if config else {},
            rapport_hint=hint,
            done=self._done,
            # A copy, so an observation already handed out is not changed by
            # later writes to self._last_info.
            metadata=dict(self._last_info),
        )