Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| ProcureRL Environment Implementation. | |
| An OpenEnv-compliant RL environment for procurement negotiation where | |
| an LLM agent learns to negotiate against scripted supplier opponents. | |
| """ | |
| import uuid | |
| from typing import Optional, Dict, Any | |
| try: | |
| from openenv.core.env_server.interfaces import Environment | |
| except ImportError: | |
| Environment = object | |
| try: | |
| from ..models import NegotiationAction, NegotiationObservation, NegotiationState | |
| from ..opponent import ScriptedPersonaOpponent | |
| from ..graders import grade | |
| except ImportError: | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models import NegotiationAction, NegotiationObservation, NegotiationState | |
| from opponent import ScriptedPersonaOpponent | |
| from graders import grade | |
# Scenario table, keyed by task id. Each entry names the scripted opponent
# persona, the maximum number of rounds before the episode is cut off, and
# the per-issue buyer constraints that are surfaced to the agent in every
# observation.
TASK_CONFIG = dict(
    single_issue=dict(
        persona="cooperative",
        max_rounds=6,
        buyer_constraints=dict(
            price=dict(target=36000, worst=55000, budget=53000),
        ),
    ),
    multi_issue=dict(
        persona="cash_flow_stressed",
        max_rounds=8,
        buyer_constraints=dict(
            price=dict(target=40000, worst=58000, budget=55000),
            payment_days=dict(target=60, worst=30, preference=60),
        ),
    ),
    adversarial=dict(
        persona="aggressive_anchor",
        max_rounds=10,
        buyer_constraints=dict(
            price=dict(target=80000, worst=120000, budget=115000),
            payment_days=dict(target=60, worst=30, preference=60),
            support_hours=dict(target=150, worst=80, preference=150),
        ),
    ),
)

# Move types accepted by ``step``; anything else yields an error observation.
VALID_MOVES = (
    "make_offer",
    "accept",
    "reject",
    "bundle",
)
class ProcureRLEnvironment(Environment):
    """Procurement-negotiation environment played against a scripted supplier.

    ``reset`` selects one of the scenarios in ``TASK_CONFIG`` and returns the
    supplier's opening offer.  Each ``step`` consumes one agent move
    (``make_offer``, ``bundle``, ``accept`` or ``reject``) and returns the next
    observation.  An episode ends when a deal is struck (by either side) or
    when the scenario's ``max_rounds`` is reached; only a deal produces a
    graded reward, every other step carries a reward of ``0.0``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = NegotiationState()
        self._opponent = None
        self._task_config = None
        self._done = False
        self._last_offer: Dict[str, Any] = {}
        self._consecutive_concessions = 0
        self._prev_agent_price: Optional[float] = None
        self._exchanges: list = []
        self._last_info: Dict[str, Any] = {}
        # Populated by reset(); declared here so the attribute always exists.
        self._opponent_opening_price: Optional[float] = None

    def reset(
        self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs
    ) -> NegotiationObservation:
        """Start a new episode and return the supplier's opening observation.

        Args:
            seed: Base seed for the scripted opponent; ``42`` when omitted.
            episode_id: Identifier stored on the state; a short random id is
                generated when omitted.
            **kwargs: ``task_id`` selects the scenario (default
                ``"single_issue"``).

        Returns:
            The opening observation, or a ``done`` observation whose metadata
            carries ``error = "unknown_task:<id>"`` when ``task_id`` is not a
            key of ``TASK_CONFIG``.
        """
        import zlib  # local import: the only stdlib name this class adds

        task_id = kwargs.get("task_id", "single_issue")
        seed = seed if seed is not None else 42
        if task_id not in TASK_CONFIG:
            return self._error_obs(
                f"Unknown task: {task_id}. Valid: {list(TASK_CONFIG.keys())}",
                f"unknown_task:{task_id}",
                done=True,
            )

        config = TASK_CONFIG[task_id]
        self._task_config = config
        self._done = False
        self._consecutive_concessions = 0
        self._prev_agent_price = None
        self._exchanges = []
        self._last_info = {}

        # The built-in hash() of a str is salted per process, so it cannot be
        # used to derive a reproducible seed.  CRC-32 is stable across runs
        # and already lies in [0, 2**32).
        opponent_seed = zlib.crc32(f"{seed}:{task_id}".encode("utf-8"))
        self._opponent = ScriptedPersonaOpponent(
            task_id=task_id, seed=opponent_seed, persona=config["persona"]
        )
        opening_msg, opening_terms = self._opponent.get_opening_message()
        self._last_offer = opening_terms
        # Fallback used when the opening terms carry no "price" entry.
        self._opponent_opening_price = opening_terms.get("price", 52000.0)

        self._state = NegotiationState(
            task_id=task_id,
            episode_id=episode_id or str(uuid.uuid4())[:8],
            round_number=0,
            step_count=0,
            rapport_score=0.5,
            consecutive_concessions=0,
            deal_reached=False,
            final_terms=None,
            cumulative_reward=0.0,
        )
        self._exchanges.append(
            {"role": "supplier", "message": opening_msg, "terms": opening_terms}
        )
        return NegotiationObservation(
            task_id=task_id,
            round_number=0,
            max_rounds=config["max_rounds"],
            supplier_message=opening_msg,
            current_offer=opening_terms,
            last_4_exchanges=self._exchanges[-4:],
            buyer_constraints=config["buyer_constraints"],
            rapport_hint="neutral",
            done=False,
        )

    def step(self, action: NegotiationAction, **kwargs) -> NegotiationObservation:
        """Apply one agent move and return the resulting observation.

        Args:
            action: A ``NegotiationAction``.  A plain dict with ``move_type``,
                ``terms`` and ``message`` keys is also accepted; any other
                object is treated as an empty ``make_offer``.

        Returns:
            The next observation.  Problems are reported through
            ``metadata["error"]`` (``episode_done``, ``not_initialized``,
            ``invalid_move_type:<x>``, ``max_rounds_reached``,
            ``rejected_at_limit``) rather than by raising.  An invalid move
            type does not consume a round.
        """
        self._last_info = {}
        if self._done:
            return self._error_obs(
                "Episode finished. Call reset().", "episode_done", done=True
            )
        if self._task_config is None:
            return self._error_obs(
                "Environment not initialized. Call reset() first.",
                "not_initialized",
                done=True,
            )

        if not isinstance(action, NegotiationAction):
            payload = (
                action if isinstance(action, dict) else {"move_type": "make_offer"}
            )
            action = NegotiationAction(
                move_type=payload.get("move_type", "make_offer"),
                terms=payload.get("terms", {}),
                message=payload.get("message", ""),
            )

        if action.move_type not in VALID_MOVES:
            return self._error_obs(
                None, f"invalid_move_type:{action.move_type}", done=False
            )

        self._state.round_number += 1
        self._state.step_count += 1

        # Guard against an action whose terms are missing altogether.
        terms = action.terms or {}
        self._track_concessions(terms)

        if action.move_type in ("make_offer", "bundle"):
            return self._handle_offer(action, terms)
        if action.move_type == "accept":
            return self._finish_deal(self._last_offer)
        return self._handle_reject()

    def state(self) -> NegotiationState:
        """Return the live episode state object.

        NOTE(review): some environment interfaces expose ``state`` as a
        property; kept as a plain method here to preserve the existing call
        signature -- confirm against the base ``Environment`` contract.
        """
        return self._state

    def close(self) -> None:
        """Release resources; this environment holds none."""
        pass

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_price(value: Any) -> Optional[float]:
        """Return ``value`` as a float, or ``None`` when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _track_concessions(self, terms: Dict[str, Any]) -> None:
        """Update the run of consecutive price increases offered by the agent.

        A price strictly above the agent's previous price extends the run;
        any other numeric price resets it.  Moves without a usable price
        leave the run unchanged.
        """
        if "price" in terms:
            price = self._coerce_price(terms["price"])
            if price is not None:
                if self._prev_agent_price is not None:
                    if price > self._prev_agent_price:
                        self._consecutive_concessions += 1
                    else:
                        self._consecutive_concessions = 0
                self._prev_agent_price = price
        self._state.consecutive_concessions = self._consecutive_concessions

    def _handle_offer(
        self, action: NegotiationAction, terms: Dict[str, Any]
    ) -> NegotiationObservation:
        """Forward an offer/bundle to the opponent and build the observation."""
        round_num = self._state.round_number
        opponent_msg, opponent_terms = self._opponent.respond(
            agent_message=action.message,
            agent_terms=terms,
            round_number=round_num,
            consecutive_concessions=self._consecutive_concessions,
        )
        # Keys starting with "_" are opponent control flags, not offer terms.
        public_terms = {
            k: v for k, v in opponent_terms.items() if not k.startswith("_")
        }
        # Record both sides before any observation is built so that the
        # returned history always includes the supplier's latest message.
        self._exchanges.append(
            {"role": "agent", "message": action.message, "terms": terms}
        )
        self._exchanges.append(
            {"role": "supplier", "message": opponent_msg, "terms": public_terms}
        )

        if opponent_terms.get("_accepted"):
            # The supplier accepted the agent's own terms.
            return self._finish_deal(terms, supplier_message=opponent_msg)

        self._last_offer = public_terms
        self._state.rapport_score = self._opponent.rapport

        if round_num >= self._task_config["max_rounds"]:
            self._done = True
            # Recorded before _make_obs so it is present in obs.metadata.
            self._last_info["error"] = "max_rounds_reached"
            obs = self._make_obs(supplier_message=opponent_msg)
            obs.done = True
            obs.reward = 0.0
            return obs

        obs = self._make_obs(supplier_message=opponent_msg)
        obs.reward = 0.0
        return obs

    def _finish_deal(
        self, final_terms: Dict[str, Any], supplier_message: Optional[str] = None
    ) -> NegotiationObservation:
        """Close the episode with a deal on ``final_terms`` and grade it."""
        self._done = True
        self._state.deal_reached = True
        self._state.final_terms = final_terms
        reward = grade(
            self._state.task_id,
            final_terms,
            True,
            self._state.round_number,
            opponent_opening=self._opponent_opening_price,
            consecutive_concessions_flag=(self._consecutive_concessions >= 2),
        )
        self._state.cumulative_reward = reward
        # Recorded before _make_obs so it is present in obs.metadata.
        self._last_info["deal_price"] = final_terms.get("price")
        obs = self._make_obs(supplier_message=supplier_message)
        obs.done = True
        obs.reward = reward
        return obs

    def _handle_reject(self) -> NegotiationObservation:
        """Handle a reject move; it only ends the episode at the round limit."""
        if self._state.round_number >= self._task_config["max_rounds"]:
            self._done = True
            # Recorded before _make_obs so it is present in obs.metadata.
            self._last_info["error"] = "rejected_at_limit"
            obs = self._make_obs()
            obs.done = True
        else:
            obs = self._make_obs()
        obs.reward = 0.0
        return obs

    def _error_obs(
        self, message: Optional[str], code: str, *, done: bool
    ) -> NegotiationObservation:
        """Build an observation whose metadata reports ``code`` as the error."""
        self._last_info = {"error": code}
        obs = self._make_obs(message)
        if done:
            obs.done = True
        return obs

    def _make_obs(
        self, supplier_message: Optional[str] = None
    ) -> NegotiationObservation:
        """Snapshot the current episode into an observation.

        The rapport score is reported only as a coarse hint: ``positive`` at
        0.65 and above, ``negative`` at 0.35 and below, ``neutral`` otherwise.
        """
        rapport = self._state.rapport_score
        if rapport >= 0.65:
            hint = "positive"
        elif rapport <= 0.35:
            hint = "negative"
        else:
            hint = "neutral"
        config = self._task_config
        return NegotiationObservation(
            task_id=self._state.task_id or "",
            round_number=self._state.round_number,
            max_rounds=config["max_rounds"] if config else 0,
            supplier_message=supplier_message or "",
            current_offer=self._last_offer,
            last_4_exchanges=self._exchanges[-4:] if self._exchanges else [],
            buyer_constraints=config["buyer_constraints"] if config else {},
            rapport_hint=hint,
            done=self._done,
            metadata=self._last_info,
        )