# procure-rl / models.py
# Uploaded by akshaypulla using huggingface_hub (commit c1be7c3, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the ProcureRL Environment.
The ProcureRL environment is a procurement negotiation RL environment where
an LLM agent learns to negotiate against scripted supplier opponents.
"""
from typing import Optional, List, Dict, Any

from pydantic import BaseModel, Field, ConfigDict, field_validator

try:
    from openenv.core.env_server.types import Action, Observation, State as OpenEnvState
except ImportError:
    # OpenEnv is optional here; fall back to a plain base when it is absent.
    OpenEnvState = object
class NegotiationAction(BaseModel):
    """One buyer move in a procurement negotiation.

    ``move_type`` must be one of ``make_offer``, ``accept``, ``reject`` or
    ``bundle``; ``terms`` carries the proposed deal terms and ``message`` the
    accompanying text sent to the supplier.
    """

    # extra="allow": undeclared keys supplied by the agent are kept instead of
    # rejected. validate_assignment=True: field validation (including the
    # move_type check below) also runs when attributes are reassigned after
    # construction, so an instance can never hold an invalid move_type.
    model_config = ConfigDict(extra="allow", validate_assignment=True)

    move_type: str = Field(
        default="make_offer",
        description="Choose action: make_offer (propose), accept (take current deal), reject (walk away), bundle (multi-issue offer)",
    )
    terms: Dict[str, Any] = Field(
        default_factory=lambda: {"price": 45000},
        description='For single_issue: {"price": 45000}. For multi_issue: {"price": 45000, "payment_days": 30}. For adversarial: add "support_hours": 100',
    )
    message: str = Field(
        default="I value our partnership and believe we can reach a fair agreement together.",
        description="Write a collaborative message. Use: partnership, mutual, flexible, understand, solution. Avoid: demand, final offer, ultimatum",
    )

    @field_validator("move_type")
    @classmethod
    def validate_move_type(cls, value: str) -> str:
        """Reject any ``move_type`` outside the supported set.

        Raises:
            ValueError: if ``value`` is not a supported move. Pydantic wraps it
                in a ``ValidationError``, which is itself a ``ValueError``
                subclass, so ``except ValueError`` callers still catch it.
        """
        valid_moves = ("make_offer", "accept", "reject", "bundle")
        if value not in valid_moves:
            raise ValueError(
                f"Invalid move_type: {value}. Must be one of {valid_moves}"
            )
        return value
class NegotiationObservation(BaseModel):
    """What the agent sees after each negotiation step.

    Every field has a default, so an observation can be built with no
    arguments; undeclared keys are accepted and retained.
    """

    model_config = ConfigDict(extra="allow")

    # Episode identification and progress counters.
    task_id: str = Field(default="")
    round_number: int = Field(default=0)
    max_rounds: int = Field(default=0)

    # Supplier-side content for this turn; dict contents are not validated here.
    supplier_message: str = Field(default="")
    current_offer: Dict[str, Any] = Field(default_factory=dict)

    # Recent history. NOTE(review): the name suggests at most four entries,
    # but no length limit is enforced by this model.
    last_4_exchanges: List[Dict] = Field(default_factory=list)

    buyer_constraints: Dict[str, Any] = Field(default_factory=dict)
    rapport_hint: str = Field(default="neutral")

    # Step outcome; reward stays None until something assigns it.
    done: bool = Field(default=False)
    reward: Optional[float] = Field(default=None)
    metadata: Dict[str, Any] = Field(default_factory=dict)
class _StateKeyError(KeyError, AttributeError):
    """Missing-key error raised by ``NegotiationState.__getitem__``.

    It subclasses ``KeyError`` because that is what mapping-style callers
    expect. It also subclasses ``AttributeError`` because that is what the
    earlier getattr-based implementation raised, so existing
    ``except AttributeError`` handlers keep working.
    """


class NegotiationState(BaseModel):
    """Per-episode negotiation state.

    Supports read-only dict-style access (``state["round_number"]``,
    ``state.get("rapport_score")``) alongside normal attribute access.
    Assignments are re-validated against the declared field types.
    """

    model_config = ConfigDict(extra="allow", validate_assignment=True)

    task_id: str = ""
    episode_id: str = ""
    round_number: int = 0
    step_count: int = 0  # Required by OpenEnv web interface
    rapport_score: float = 0.5  # no range is enforced by this model
    consecutive_concessions: int = 0
    deal_reached: bool = False
    final_terms: Optional[Dict] = None
    cumulative_reward: float = 0.0

    def __getitem__(self, key):
        """Return the attribute named ``key`` (declared field or extra).

        Raises:
            KeyError: (as ``_StateKeyError``, also an ``AttributeError``)
                when no attribute named ``key`` exists.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            # Convert to a KeyError-compatible exception, as the mapping
            # protocol expects, without breaking AttributeError handlers.
            raise _StateKeyError(key) from None

    def get(self, key, default=None):
        """Return the attribute named ``key``, or ``default`` if it is absent."""
        return getattr(self, key, default)