Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Data models for the ProcureRL Environment. | |
| The ProcureRL environment is a procurement negotiation RL environment where | |
| an LLM agent learns to negotiate against scripted supplier opponents. | |
| """ | |
| from typing import Optional, List, Dict, Any | |
| from pydantic import BaseModel, Field, ConfigDict | |
| try: | |
| from openenv.core.env_server.types import Action, Observation, State as OpenEnvState | |
| except ImportError: | |
| OpenEnvState = object | |
class NegotiationAction(BaseModel):
    # One buyer move submitted to the negotiation environment.
    # ``extra="allow"`` keeps unknown keys instead of rejecting them; the
    # field descriptions below are written as instructions to the agent.
    model_config = ConfigDict(extra="allow")

    # Which move to make; checked against the supported set in
    # ``model_post_init``.
    move_type: str = Field(
        default="make_offer",
        description="Choose action: make_offer (propose), accept (take current deal), reject (walk away), bundle (multi-issue offer)",
    )

    # Proposed deal terms; a fresh dict is built per instance by the factory.
    terms: Dict[str, Any] = Field(
        default_factory=lambda: {"price": 45000},
        description='For single_issue: {"price": 45000}. For multi_issue: {"price": 45000, "payment_days": 30}. For adversarial: add "support_hours": 100',
    )

    # Free-text message accompanying the move.
    message: str = Field(
        default="I value our partnership and believe we can reach a fair agreement together.",
        description="Write a collaborative message. Use: partnership, mutual, flexible, understand, solution. Avoid: demand, final offer, ultimatum",
    )

    def model_post_init(self, *args, **kwargs):
        """Reject any ``move_type`` outside the supported moves.

        Runs after pydantic has built the instance. Later attribute
        assignment is not re-checked here.

        Raises:
            ValueError: if ``move_type`` is not one of ``make_offer``,
                ``accept``, ``reject`` or ``bundle``.
        """
        valid_moves = ("make_offer", "accept", "reject", "bundle")
        if self.move_type in valid_moves:
            return
        error_text = f"Invalid move_type: {self.move_type}. Must be one of {valid_moves}"
        raise ValueError(error_text)
class NegotiationObservation(BaseModel):
    # Observation handed back to the negotiating agent after a step.
    # ``extra="allow"`` keeps any additional keys supplied at construction.
    # Field order is significant for serialization and must be preserved.
    model_config = ConfigDict(extra="allow")

    # Task and round progress.
    task_id: str = Field(default="")
    round_number: int = Field(default=0)
    max_rounds: int = Field(default=0)

    # Supplier side: latest message, the offer currently on the table, and a
    # history list (the name suggests the last four exchanges; the length is
    # not enforced here).
    supplier_message: str = Field(default="")
    current_offer: Dict[str, Any] = Field(default_factory=dict)
    last_4_exchanges: List[Dict] = Field(default_factory=list)

    # Buyer side: constraints and a coarse rapport indicator.
    buyer_constraints: Dict[str, Any] = Field(default_factory=dict)
    rapport_hint: str = Field(default="neutral")

    # Step result: termination flag, optional reward, free-form metadata.
    done: bool = Field(default=False)
    reward: Optional[float] = Field(default=None)
    metadata: Dict[str, Any] = Field(default_factory=dict)
class _StateKeyError(KeyError, AttributeError):
    """Raised by ``NegotiationState[key]`` when *key* is not a stored value.

    Subclasses ``KeyError`` (the mapping-protocol convention) and also
    ``AttributeError`` (what the previous ``getattr``-based lookup raised),
    so existing handlers written for either exception keep working.
    """


class NegotiationState(BaseModel):
    # Mutable per-episode negotiation state.
    # ``validate_assignment=True`` re-validates fields on every attribute
    # write; ``extra="allow"`` lets callers attach additional keys, which are
    # also reachable through the dict-style accessors below.
    model_config = ConfigDict(extra="allow", validate_assignment=True)

    task_id: str = ""
    episode_id: str = ""
    round_number: int = 0
    step_count: int = 0  # Required by OpenEnv web interface
    rapport_score: float = 0.5
    consecutive_concessions: int = 0
    deal_reached: bool = False
    final_terms: Optional[Dict] = None
    cumulative_reward: float = 0.0

    def __getitem__(self, key):
        """Return the value stored under *key* (a declared or extra field).

        Only data fields are visible. Method names and class attributes such
        as ``"model_dump"`` or ``"model_config"`` are treated as missing.

        Raises:
            KeyError: if *key* is not a field of this state. The exception is
                also an ``AttributeError`` for backward compatibility.
        """
        if key in type(self).model_fields:
            return getattr(self, key)
        # Extra (undeclared) keys kept because of ``extra="allow"``.
        extras = self.model_extra or {}
        if key in extras:
            return extras[key]
        raise _StateKeyError(key)

    def get(self, key, default=None):
        """Return the value stored under *key*, or *default* if it is absent.

        Follows ``dict.get`` semantics over the state's data fields (declared
        and extra). Non-data attribute names return *default*.
        """
        try:
            return self[key]
        except KeyError:
            return default