# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the ProcureRL Environment.

The ProcureRL environment is a procurement negotiation RL environment where
an LLM agent learns to negotiate against scripted supplier opponents.
"""

from typing import Optional, List, Dict, Any

from pydantic import BaseModel, Field, ConfigDict

# OpenEnv is treated as optional: when it cannot be imported, only
# OpenEnvState gets a placeholder value.
# NOTE(review): Action and Observation stay undefined in the fallback path;
# none of the three names is referenced by the models visible in this file.
try:
    from openenv.core.env_server.types import Action, Observation, State as OpenEnvState
except ImportError:
    OpenEnvState = object


# One negotiation move: the move type, the proposed terms and an accompanying
# message. Unknown extra fields are accepted (extra="allow").
# Class-level notes in this module are `#` comments rather than docstrings
# because pydantic copies a model docstring into the generated JSON schema.
class NegotiationAction(BaseModel):
    model_config = ConfigDict(extra="allow")

    move_type: str = Field(
        default="make_offer",
        description="Choose action: make_offer (propose), accept (take current deal), reject (walk away), bundle (multi-issue offer)",
    )
    terms: Dict[str, Any] = Field(
        default_factory=lambda: {"price": 45000},
        description='For single_issue: {"price": 45000}. For multi_issue: {"price": 45000, "payment_days": 30}. For adversarial: add "support_hours": 100',
    )
    message: str = Field(
        default="I value our partnership and believe we can reach a fair agreement together.",
        description="Write a collaborative message. Use: partnership, mutual, flexible, understand, solution. Avoid: demand, final offer, ultimatum",
    )

    def model_post_init(self, *args, **kwargs):
        """Reject any ``move_type`` outside the four supported moves.

        The positional and keyword arguments are accepted and ignored.

        Raises:
            ValueError: if ``move_type`` is not one of ``make_offer``,
                ``accept``, ``reject`` or ``bundle``.
        """
        allowed = ("make_offer", "accept", "reject", "bundle")
        if self.move_type in allowed:
            return
        # The tuple itself is interpolated, so its repr appears in the message.
        raise ValueError(
            f"Invalid move_type: {self.move_type}. Must be one of {allowed}"
        )


# Observation payload: task and round counters, the supplier's message and
# current offer, recent exchanges, buyer constraints, a rapport hint, and the
# done / reward / metadata fields. Every field has a default, and unknown
# extra fields are accepted (extra="allow").
class NegotiationObservation(BaseModel):
    model_config = ConfigDict(extra="allow")

    task_id: str = ""
    round_number: int = 0
    max_rounds: int = 0
    supplier_message: str = ""
    current_offer: Dict[str, Any] = Field(default_factory=dict)
    last_4_exchanges: List[Dict] = Field(default_factory=list)
    buyer_constraints: Dict[str, Any] = Field(default_factory=dict)
    rapport_hint: str = "neutral"
    done: bool = False
    reward: Optional[float] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)


# Per-episode bookkeeping. The config enables validate_assignment, so values
# are validated again when an attribute is assigned, and allows extra fields.
# __getitem__ and get() add dict-style read access on top of attribute access.
class NegotiationState(BaseModel):
    model_config = ConfigDict(extra="allow", validate_assignment=True)

    task_id: str = ""
    episode_id: str = ""
    round_number: int = 0
    step_count: int = 0  # Required by OpenEnv web interface
    rapport_score: float = 0.5
    consecutive_concessions: int = 0
    deal_reached: bool = False
    final_terms: Optional[Dict] = None
    cumulative_reward: float = 0.0

    def __getitem__(self, key):
        """Return the attribute named ``key``, enabling ``state[key]``.

        The lookup is a plain ``getattr``, so a missing name raises
        ``AttributeError`` rather than ``KeyError``.
        """
        return getattr(self, key)

    def get(self, key, default=None):
        """Return the attribute named ``key``, or ``default`` if it is absent."""
        return getattr(self, key, default)