# procure-rl / opponent.py
# Uploaded by akshaypulla via huggingface_hub (commit c1be7c3).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Scripted persona opponent for procurement negotiation.
The opponent's behavior is deterministic given a seed AND sensitive to
the agent's language quality via the rapport system.
"""
import random
import re
from dataclasses import dataclass, field
from typing import Dict, Tuple
# Lower-case phrases that mark a cooperative tone. Each one found in an agent
# message raises the opponent's rapport (ScriptedPersonaOpponent.update_rapport
# adds 0.08 per hit, with the per-message change capped at +/-0.20).
COLLABORATIVE_SIGNALS = [
    "understand", "partnership", "mutual", "together", "value",
    "appreciate", "flexible", "work with", "long-term", "relationship",
    "reasonable", "fair", "both", "solution",
]
# Lower-case phrases that mark a confrontational tone. Each one found in an
# agent message lowers the opponent's rapport by 0.08 (per-message change capped
# at +/-0.20). "take it or leave" is deliberately truncated so that
# "take it or leave it" also matches.
AGGRESSIVE_SIGNALS = [
    "demand", "require", "final offer", "unacceptable",
    "must", "non-negotiable", "take it or leave", "bottom line",
    "ultimatum", "insist", "refuse", "absolutely not",
]
# Canned dialogue for each opponent persona.
#
# List-valued entries are alternatives; ScriptedPersonaOpponent picks one with
# its seeded RNG. "accept"/"reject" are single strings.
#
# Placeholders are filled by plain str.replace:
#   ${target}                    - in get_opening_message
#   ${counter}, ${close}, ${floor} - in respond
#
# Only "aggressive_anchor" defines "hardening".
PERSONA_TEMPLATES = {
    "cooperative": dict(
        opening=[
            "Thanks for reaching out. Our standard pricing for this package is ${target}. Happy to discuss.",
            "We value your interest. We're pricing this at ${target} based on current market rates.",
        ],
        counter=[
            "I appreciate you working with us. Based on our costs, ${counter} is where we can be.",
            "Thank you for your offer. We can move to ${counter} given our margin requirements.",
        ],
        near_close=[
            "I think we're close. If you can do ${close}, I can get this approved today.",
            "We're almost there. ${close} works for our team. Shall we finalize?",
        ],
        accept="That works for us. Let's move forward at those terms.",
        reject="That's below what we can accept, but we want to make this work.",
    ),
    "cash_flow_stressed": dict(
        opening=[
            "Our pricing is ${target}. I should mention — payment timing is particularly important to us this quarter.",
            "We're at ${target}. Between us, our finance team has specific requirements around cash flow timing.",
        ],
        counter=[
            "We can move on price if payment terms work for you. ${counter} with your payment preference?",
            "Price flexibility depends on receivables timing for us. ${counter} if we can discuss payment terms.",
        ],
        near_close=[
            "If you can do Net-30 on payment, we can get to ${close} on price.",
            "Payment timing is our real constraint. ${close} with faster payment terms?",
        ],
        accept="Agreed. The payment structure works for our cash flow needs.",
        reject="The price is tight but we could explore it if payment terms align.",
    ),
    "aggressive_anchor": dict(
        opening=[
            "Our price is ${target}. This reflects our full service quality and market position.",
            "We're firm at ${target}. This is based on our cost structure and service level.",
        ],
        counter=[
            "We can go to ${counter}. That's already a significant concession from our position.",
            "${counter} is our revised position. We're not in a position to move much further.",
        ],
        hardening=[
            "We've already moved considerably. ${floor} is our absolute position.",
            "I need to be direct — we're at ${floor} and that's where we'll stay.",
        ],
        near_close=[
            "Final position: ${close}. We need a decision today.",
            "${close} is where we are. This is our best and final offer.",
        ],
        accept="Accepted.",
        reject="That doesn't work. Come back with a serious offer.",
    ),
}
class ScriptedPersonaOpponent:
    """Scripted seller that negotiates price according to a fixed persona.

    The opponent draws a hidden reservation price (``price_floor``) and opens
    at a random markup over it (``price_target``). Each non-accepting response
    lowers its asking price (``current_position``) geometrically toward the
    floor. The step size is set by the persona and scaled by a rapport score
    that reacts to the wording of agent messages.

    All randomness comes from a private ``random.Random(seed)``. The same
    ``(task_id, seed, persona)`` with the same sequence of agent turns
    therefore replays the same dialogue.

    Args:
        task_id: ``"single_issue"``, ``"multi_issue"`` or ``"adversarial"``;
            selects the price ranges and the round budget.
        seed: Seed for the opponent's private RNG.
        persona: Key into ``PERSONA_TEMPLATES``.

    Raises:
        ValueError: If ``task_id`` is not one of the known tasks.
        KeyError: If ``persona`` has no entry in ``PERSONA_TEMPLATES``.
    """

    # Round budget per task. respond() switches to "near_close" phrasing once
    # round_number reaches 70% of this.
    _MAX_ROUNDS = {"single_issue": 6, "multi_issue": 8, "adversarial": 10}

    # Fraction of the current asking price conceded per turn at neutral (0.5)
    # rapport.
    _BASE_CONCESSION_RATES = {
        "cooperative": 0.05,
        "cash_flow_stressed": 0.07,
        "aggressive_anchor": 0.04,
    }

    def __init__(self, task_id: str, seed: int, persona: str):
        self.rng = random.Random(seed)
        self.task_id = task_id
        self.persona = persona
        # Draw the floor first, then the opening markup over it. The order of
        # RNG draws is part of the seeded-determinism contract; keep it stable.
        if task_id == "single_issue":
            self.price_floor = self.rng.uniform(42000, 46000)
            self.price_target = self.price_floor * self.rng.uniform(1.28, 1.38)
        elif task_id == "multi_issue":
            self.price_floor = self.rng.uniform(40000, 46000)
            self.price_target = self.price_floor * self.rng.uniform(1.25, 1.35)
            # NOTE(review): not read anywhere in this class; confirm whether
            # something outside this class uses it.
            self.payment_preference = self.rng.choice([30, 45, 60])
        elif task_id == "adversarial":
            self.price_floor = self.rng.uniform(85000, 95000)
            self.price_target = self.price_floor * self.rng.uniform(1.30, 1.40)
        else:
            # Fail fast: an unknown task would otherwise leave the price
            # attributes unset and only surface as an AttributeError later.
            raise ValueError(
                f"unknown task_id {task_id!r}; expected one of "
                f"{sorted(self._MAX_ROUNDS)}"
            )
        self.templates = PERSONA_TEMPLATES[persona]
        self.rapport = 0.5  # neutral; clamped to [0, 1] by update_rapport
        self.concession_count = 0  # number of respond() calls so far
        self.current_position = self.price_target

    @staticmethod
    def _mentions(signal: str, text: str) -> bool:
        """Return True if ``signal`` occurs in ``text`` starting at a word boundary.

        Only the start is anchored. Inflections such as "appreciated",
        "understanding" or "demands" still count, while negated or compound
        forms no longer match: plain substring search had scored "unfair",
        "unreasonable", "undervalue" and "misunderstand" as collaborative.
        """
        return re.search(r"\b" + re.escape(signal), text) is not None

    def update_rapport(self, agent_message: str) -> None:
        """Adjust ``rapport`` from the tone of ``agent_message``.

        Each collaborative signal present adds 0.08 and each aggressive signal
        subtracts 0.08. A signal counts once per message, however often it
        appears. The net change is capped at +/-0.20, and rapport stays in
        [0, 1].
        """
        msg_lower = agent_message.lower()
        delta = 0.0
        delta += sum(0.08 for w in COLLABORATIVE_SIGNALS if self._mentions(w, msg_lower))
        delta -= sum(0.08 for w in AGGRESSIVE_SIGNALS if self._mentions(w, msg_lower))
        delta = max(-0.20, min(0.20, delta))
        self.rapport = max(0.0, min(1.0, self.rapport + delta))

    def get_concession_rate(self) -> float:
        """Return the fraction of the asking price to concede this turn.

        The persona's base rate is scaled linearly by rapport: from 0.5x at
        rapport 0 up to 1.5x at rapport 1. The result is never below 0.01.
        """
        base = self._BASE_CONCESSION_RATES[self.persona]
        modifier = (self.rapport - 0.5) * base
        return max(0.01, base + modifier)

    def respond(
        self,
        agent_message: str,
        agent_terms: Dict,
        round_number: int,
        consecutive_concessions: int,
    ) -> Tuple[str, Dict]:
        """Reply to one agent turn with either an acceptance or a counter-offer.

        Args:
            agent_message: Free-text message from the agent; updates rapport.
            agent_terms: The agent's proposed terms. ``"price"`` is read, with
                a missing price treated as 0. For ``cash_flow_stressed``,
                ``"payment_days"`` is also read.
            round_number: Current round. Acceptance is only possible from
                round 2 onward. "near_close" phrasing starts at 70% of the
                task's round budget.
            consecutive_concessions: Caller-supplied count. At 2 or more, an
                ``aggressive_anchor`` opponent hardens. Presumably this is the
                agent's run of consecutive concessions; confirm against the
                caller.

        Returns:
            ``(message, terms)``. On acceptance, ``terms`` is a copy of
            ``agent_terms`` plus ``"_accepted": True``. Otherwise it is a copy
            of ``agent_terms`` whose ``"price"`` is the opponent's new asking
            price, rounded to 2 decimal places.

        Side effects: every call updates ``rapport`` and ``concession_count``.
        A counter-offer also lowers ``current_position`` and draws a template
        from the RNG.
        """
        self.update_rapport(agent_message)
        self.concession_count += 1
        agent_price = agent_terms.get("price", 0)
        if (
            round_number >= 2
            and agent_price >= self.price_floor
            and self._acceptance_condition(agent_terms)
        ):
            return self.templates["accept"], {**agent_terms, "_accepted": True}

        concession = self.get_concession_rate()
        if self.persona == "aggressive_anchor" and consecutive_concessions >= 2:
            # Concede at only 40% of the normal rate and use "hardening" copy.
            concession = concession * 0.4
            template_key = "hardening"
        elif round_number >= self._max_rounds() * 0.7:
            template_key = "near_close"
        else:
            template_key = "counter"

        # Geometric step down from the current ask, never below the floor.
        new_position = self.current_position * (1 - concession)
        new_position = max(self.price_floor, new_position)
        self.current_position = new_position

        templates_for_key = self.templates.get(template_key, self.templates["counter"])
        template = self.rng.choice(templates_for_key)
        message = template.replace("${counter}", f"${new_position:,.0f}")
        # NOTE(review): "hardening" templates quote price_floor as the
        # "absolute position", but the returned terms carry new_position, which
        # can sit above the floor. This discloses the reservation price; confirm
        # that is intended.
        message = message.replace("${floor}", f"${self.price_floor:,.0f}")
        message = message.replace("${close}", f"${new_position:,.0f}")

        counter_terms = dict(agent_terms)
        counter_terms["price"] = round(new_position, 2)
        if self.persona == "cash_flow_stressed" and "payment_days" in agent_terms:
            if agent_terms["payment_days"] > 60:
                message += (
                    " Though I'll need to flag the payment timing to our finance team."
                )
        return message, counter_terms

    def _acceptance_condition(self, terms: Dict) -> bool:
        """Return whether non-price terms are acceptable to this persona.

        Only ``cash_flow_stressed`` imposes a condition: payment within 45
        days. A missing ``"payment_days"`` is treated as 60 and is therefore
        rejected.

        NOTE(review): get_opening_message only emits payment_days for
        multi_issue/adversarial. If this persona is ever paired with
        single_issue, it may never accept; confirm the persona/task pairing.
        """
        if self.persona == "cash_flow_stressed":
            payment_ok = terms.get("payment_days", 60) <= 45
            return payment_ok
        return True

    def _max_rounds(self) -> int:
        """Return the round budget for this opponent's task."""
        return self._MAX_ROUNDS[self.task_id]

    def get_opening_message(self) -> Tuple[str, Dict]:
        """Return the opening line and opening terms.

        The price is ``price_target``, rounded to 2 decimal places.
        multi_issue and adversarial add ``payment_days=90``; adversarial also
        adds ``support_hours=80``. Draws one template choice from the RNG.
        """
        template = self.rng.choice(self.templates["opening"])
        message = template.replace("${target}", f"${self.price_target:,.0f}")
        terms = {"price": round(self.price_target, 2)}
        if self.task_id in ["multi_issue", "adversarial"]:
            terms["payment_days"] = 90
        if self.task_id == "adversarial":
            terms["support_hours"] = 80
        return message, terms