Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
# ================================================================
|
| 2 |
-
# ANP
|
|
|
|
|
|
|
| 3 |
# ================================================================
|
| 4 |
-
import os, time, math, random, uuid,
|
| 5 |
-
from typing import List, Dict
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
import torch.nn.functional as F
|
| 10 |
-
from torch.utils.data import
|
| 11 |
from torch.optim import AdamW
|
| 12 |
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 13 |
from transformers import BertTokenizerFast
|
|
@@ -20,399 +22,1468 @@ import matplotlib.pyplot as plt
|
|
| 20 |
random.seed(42)
|
| 21 |
torch.manual_seed(42)
|
| 22 |
|
| 23 |
-
# ββ Config
|
| 24 |
-
DEVICE = torch.device("
|
| 25 |
-
|
|
|
|
| 26 |
MSG2IDX = {m: i for i, m in enumerate(MSG_TYPES)}
|
| 27 |
IDX2MSG = {i: m for m, i in MSG2IDX.items()}
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
CAT2IDX = {c: i for i, c in enumerate(CATEGORIES)}
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
D_MODEL = 384
|
| 33 |
N_HEADS = 6
|
| 34 |
N_LAYERS = 6
|
| 35 |
FFN_DIM = 1024
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
|
|
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def generate_sessions(n_sessions: int) -> List[Dict]:
|
| 65 |
-
"""Generates synthetic negotiation data quickly in memory."""
|
| 66 |
all_rows = []
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
lp
|
| 84 |
-
sid = f"SYN-{uuid.uuid4().hex[:6].upper()}"
|
| 85 |
-
turn = 0
|
| 86 |
-
session_rows = []
|
| 87 |
|
| 88 |
def add(party, price, mtype, msg):
|
| 89 |
nonlocal turn
|
| 90 |
turn += 1
|
| 91 |
-
|
| 92 |
-
"session_id":
|
| 93 |
-
"
|
| 94 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
})
|
| 96 |
|
| 97 |
sp = lp
|
| 98 |
-
bp = round(lp * random.uniform(
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
else:
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
return all_rows
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
}
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
class PositionalEncoding(nn.Module):
|
| 177 |
def __init__(self, d: int, max_len: int = 512):
|
| 178 |
super().__init__()
|
| 179 |
self.drop = nn.Dropout(0.1)
|
| 180 |
-
pe
|
| 181 |
pos = torch.arange(max_len).unsqueeze(1).float()
|
| 182 |
-
div = torch.exp(
|
|
|
|
|
|
|
| 183 |
pe[:, 0::2] = torch.sin(pos * div)
|
| 184 |
pe[:, 1::2] = torch.cos(pos * div)
|
| 185 |
self.register_buffer("pe", pe.unsqueeze(0))
|
| 186 |
|
| 187 |
-
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
class NegotiationTransformer(nn.Module):
|
| 190 |
def __init__(self):
|
| 191 |
super().__init__()
|
| 192 |
-
self.emb
|
| 193 |
-
self.pos
|
| 194 |
-
enc_layer
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
|
| 200 |
-
self.px_head
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
def forward(self, ids, mask, party, cat, ofn, tn):
|
| 203 |
-
x
|
| 204 |
-
x
|
| 205 |
cls = x[:, 0]
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
return self.msg_head(f), self.px_head(f).squeeze(1)
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
global GLOBAL_MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
try:
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
|
| 228 |
-
|
| 229 |
-
with STATE.lock:
|
| 230 |
-
STATE.total_epochs = epochs
|
| 231 |
-
STATE.losses = []
|
| 232 |
-
|
| 233 |
-
STATE.log("Entering Training Loop (CPU mode).")
|
| 234 |
total_batches = len(loader)
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
model.train()
|
| 238 |
ep_loss = 0.0
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
for i, batch in enumerate(loader):
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 256 |
-
|
|
|
|
| 257 |
ep_loss += loss.item()
|
| 258 |
-
|
| 259 |
sch.step()
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
-
STATE.log("Training complete. Applying weights to Global Model.")
|
| 267 |
model.eval()
|
| 268 |
GLOBAL_MODEL = model
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
except Exception as e:
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
t = threading.Thread(target=_training_thread_target, args=(int(n_sessions), int(epochs), int(batch_size), float(lr)), daemon=True)
|
| 287 |
-
t.start()
|
| 288 |
-
return "Background training thread triggered."
|
| 289 |
-
|
| 290 |
-
# ββ Inference with Pre-built Templates ββββββββββββββββββββββββ
|
| 291 |
-
def _get_template_message(msg_type: str, price: float, item: str, is_buyer: bool) -> str:
|
| 292 |
-
"""The 'Mouth': Translates the Model's strategy (msg_type, price) into prose."""
|
| 293 |
-
px = f"${price:,.2f}"
|
| 294 |
if is_buyer:
|
| 295 |
-
|
| 296 |
-
"offer":
|
| 297 |
-
"counter":
|
| 298 |
-
"accept":
|
| 299 |
-
"reject": "
|
| 300 |
-
"
|
| 301 |
-
"
|
| 302 |
}
|
|
|
|
|
|
|
|
|
|
| 303 |
else:
|
| 304 |
-
|
| 305 |
-
"offer":
|
| 306 |
-
"counter":
|
| 307 |
-
"accept":
|
| 308 |
-
"reject": "
|
| 309 |
-
"
|
| 310 |
-
"
|
| 311 |
}
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
if GLOBAL_MODEL is None:
|
| 316 |
-
return
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
if
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
else:
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
# ββ
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
| 369 |
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
with gr.Row():
|
| 372 |
-
n_sessions = gr.Number(value=
|
| 373 |
epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
|
| 374 |
-
batch_size = gr.Slider(
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
tr_btn
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
status_box = gr.Textbox(label="Thread Status", interactive=False)
|
| 381 |
with gr.Row():
|
| 382 |
-
log_box = gr.Textbox(label="
|
| 383 |
plt_out = gr.Plot(label="Loss Curve")
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
fn=refresh_dashboard,
|
| 388 |
-
outputs=[status_box, log_box, plt_out, gr.Textbox(visible=False)]
|
| 389 |
-
)
|
| 390 |
-
|
| 391 |
-
with gr.Tab("Inference Sandbox"):
|
| 392 |
-
ready_indicator = gr.Textbox(label="Model Status", interactive=False)
|
| 393 |
-
gr.Timer(5, active=True).tick(fn=lambda: "β
Ready" if STATE.model_ready else "β Needs Training", outputs=[ready_indicator])
|
| 394 |
-
|
| 395 |
-
with gr.Row():
|
| 396 |
-
d_cat = gr.Dropdown(CATEGORIES, value="used_car", label="Category")
|
| 397 |
-
d_pty = gr.Radio(["Seller","Buyer"], value="Buyer", label="Party to Simulate")
|
| 398 |
with gr.Row():
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
with gr.Row():
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# ================================================================
|
| 2 |
+
# ANP v5 | Bounded Multi-Agent Negotiation + Inventory Tool Use
|
| 3 |
+
# Buyer bounds Β· Seller inventory context Β· Search action head
|
| 4 |
+
# ZOPA tracking Β· Reservation prices Β· Ranked inventory matching
|
| 5 |
# ================================================================
|
| 6 |
+
import os, time, math, random, uuid, gc
|
| 7 |
+
from typing import List, Dict, Tuple, Optional
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 13 |
from torch.optim import AdamW
|
| 14 |
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 15 |
from transformers import BertTokenizerFast
|
|
|
|
| 22 |
random.seed(42)
|
| 23 |
torch.manual_seed(42)
|
| 24 |
|
| 25 |
+
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
|
| 28 |
+
MSG_TYPES = ["offer","counter","accept","reject","exit","stall","search"]
|
| 29 |
MSG2IDX = {m: i for i, m in enumerate(MSG_TYPES)}
|
| 30 |
IDX2MSG = {i: m for m, i in MSG2IDX.items()}
|
| 31 |
+
|
| 32 |
+
CATEGORIES = ["used_car","domain_name","freelance_design","saas_license",
|
| 33 |
+
"electronics","bulk_groceries","consulting"]
|
| 34 |
CAT2IDX = {c: i for i, c in enumerate(CATEGORIES)}
|
| 35 |
|
| 36 |
+
BUYER_PERSONAS = ["aggressive","patient","skeptical","impulsive","strategic"]
|
| 37 |
+
SELLER_PERSONAS = ["firm","motivated","anchoring","collaborative","desperate"]
|
| 38 |
+
BPERSONA2IDX = {p: i for i, p in enumerate(BUYER_PERSONAS)}
|
| 39 |
+
SPERSONA2IDX = {p: i for i, p in enumerate(SELLER_PERSONAS)}
|
| 40 |
+
|
| 41 |
+
MAX_LEN = 96
|
| 42 |
D_MODEL = 384
|
| 43 |
N_HEADS = 6
|
| 44 |
N_LAYERS = 6
|
| 45 |
FFN_DIM = 1024
|
| 46 |
|
| 47 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 48 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 49 |
+
torch.backends.cudnn.benchmark = True
|
| 50 |
+
|
| 51 |
+
print(f"Device: {DEVICE}")
|
| 52 |
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
| 53 |
+
GLOBAL_MODEL = None
|
| 54 |
|
| 55 |
+
# ================================================================
|
| 56 |
+
# INVENTORY DATABASE
|
| 57 |
+
# ================================================================
|
| 58 |
+
def _make_inventory() -> List[Dict]:
|
| 59 |
+
inv = []
|
| 60 |
+
templates = {
|
| 61 |
+
"used_car": [
|
| 62 |
+
("2018 Toyota Camry", "Good", 14500, 12800,
|
| 63 |
+
"sunroof,bluetooth,low miles"),
|
| 64 |
+
("2019 Honda Civic", "Excellent", 18500, 16200,
|
| 65 |
+
"one owner,new tires,clean title"),
|
| 66 |
+
("2020 Ford F-150", "Good", 28000, 24500,
|
| 67 |
+
"tow package,crew cab,4WD"),
|
| 68 |
+
("2016 BMW 3 Series", "Fair", 16000, 13500,
|
| 69 |
+
"sport package,leather,sunroof"),
|
| 70 |
+
("2021 Tesla Model 3", "Excellent", 38000, 35000,
|
| 71 |
+
"autopilot,long range,premium audio"),
|
| 72 |
+
("2017 Chevy Silverado", "Good", 22000, 19000,
|
| 73 |
+
"4WD,tow hitch,extended cab"),
|
| 74 |
+
("2015 Honda Accord", "Fair", 11000, 9200,
|
| 75 |
+
"2 owners,new brakes,cloth seats"),
|
| 76 |
+
("2022 Toyota RAV4", "Excellent", 32000, 29500,
|
| 77 |
+
"hybrid,AWD,apple carplay"),
|
| 78 |
+
],
|
| 79 |
+
"electronics": [
|
| 80 |
+
("MacBook Pro 14 M2", "Excellent", 1800, 1600,
|
| 81 |
+
"16GB RAM,512GB SSD,AppleCare"),
|
| 82 |
+
("iPhone 14 Pro", "Good", 900, 780,
|
| 83 |
+
"256GB,space black,minor scratches"),
|
| 84 |
+
("Sony 65in 4K TV", "Excellent", 750, 620,
|
| 85 |
+
"OLED,smart tv,2 years old"),
|
| 86 |
+
("iPad Air Gen5", "Good", 550, 470,
|
| 87 |
+
"wifi+cellular,pencil included"),
|
| 88 |
+
("Gaming PC RTX4070", "Excellent", 1400, 1200,
|
| 89 |
+
"32GB RAM,1TB NVMe,water cooled"),
|
| 90 |
+
("DJI Mavic 3", "Good", 900, 780,
|
| 91 |
+
"4K camera,3 batteries,case"),
|
| 92 |
+
],
|
| 93 |
+
"domain_name": [
|
| 94 |
+
("QuickLoan.io", "Premium", 12000, 9500,
|
| 95 |
+
"fintech,4 years aged,high DA"),
|
| 96 |
+
("FreshMeals.com", "Good", 4500, 3800,
|
| 97 |
+
"food delivery niche,aged 6yr"),
|
| 98 |
+
("TechPulse.net", "Good", 2200, 1800,
|
| 99 |
+
"tech blog ready,clean history"),
|
| 100 |
+
("GreenHome.co", "Excellent", 5500, 4600,
|
| 101 |
+
"eco niche,brandable,short"),
|
| 102 |
+
("RapidShip.io", "Premium", 8000, 6800,
|
| 103 |
+
"logistics niche,exact match"),
|
| 104 |
+
],
|
| 105 |
+
"freelance_design": [
|
| 106 |
+
("Logo + Brand Kit", "Standard", 800, 650,
|
| 107 |
+
"5 concepts,unlimited revisions,source files"),
|
| 108 |
+
("Website Redesign", "Premium", 3500, 2800,
|
| 109 |
+
"5 pages,mobile,figma handoff"),
|
| 110 |
+
("UI/UX App Design", "Premium", 5000, 4200,
|
| 111 |
+
"full wireframes,prototype,design system"),
|
| 112 |
+
("Social Media Pack", "Standard", 600, 480,
|
| 113 |
+
"30 templates,brand colors,canva ready"),
|
| 114 |
+
("Pitch Deck Design", "Standard", 1200, 950,
|
| 115 |
+
"20 slides,animations,2 revisions"),
|
| 116 |
+
],
|
| 117 |
+
"saas_license": [
|
| 118 |
+
("CRM Pro Annual", "Standard", 2400, 1900,
|
| 119 |
+
"unlimited users,API access,support"),
|
| 120 |
+
("Analytics Suite", "Premium", 4800, 3900,
|
| 121 |
+
"real-time,custom dashboards,export"),
|
| 122 |
+
("Project Mgmt Tool", "Standard", 1200, 980,
|
| 123 |
+
"50 users,gantt,integrations"),
|
| 124 |
+
("Email Marketing Pro", "Standard", 960, 780,
|
| 125 |
+
"100k contacts,automation,A/B"),
|
| 126 |
+
],
|
| 127 |
+
"bulk_groceries": [
|
| 128 |
+
("Organic Coffee 50lb", "Fresh", 420, 350,
|
| 129 |
+
"single origin,roasted weekly,wholesale"),
|
| 130 |
+
("Olive Oil 5 Gal", "Premium", 280, 230,
|
| 131 |
+
"extra virgin,cold press,Italian"),
|
| 132 |
+
("Almond Flour 25lb", "Fresh", 180, 145,
|
| 133 |
+
"blanched,gluten free,bulk"),
|
| 134 |
+
("Protein Powder 20lb", "Good", 260, 210,
|
| 135 |
+
"whey isolate,unflavored,NSF cert"),
|
| 136 |
+
],
|
| 137 |
+
"consulting": [
|
| 138 |
+
("SEO Audit + 90 Day Plan", "Standard", 1500, 1200,
|
| 139 |
+
"technical+content,keyword research,monthly report"),
|
| 140 |
+
("Financial Model Build", "Premium", 3500, 2900,
|
| 141 |
+
"3 statement,DCF,scenario analysis"),
|
| 142 |
+
("HR Policy Package", "Standard", 1800, 1450,
|
| 143 |
+
"employee handbook,policies,compliance"),
|
| 144 |
+
("Marketing Strategy Q", "Premium", 4200, 3500,
|
| 145 |
+
"market research,ICP,channel plan"),
|
| 146 |
+
],
|
| 147 |
+
}
|
| 148 |
+
for cat, items in templates.items():
|
| 149 |
+
for (name, cond, ask, res, feats) in items:
|
| 150 |
+
inv.append({
|
| 151 |
+
"id": str(uuid.uuid4().hex[:8]),
|
| 152 |
+
"category": cat,
|
| 153 |
+
"name": name,
|
| 154 |
+
"condition": cond,
|
| 155 |
+
"ask_price": ask,
|
| 156 |
+
"reservation_price": res,
|
| 157 |
+
"features": feats,
|
| 158 |
+
"notes": "",
|
| 159 |
+
})
|
| 160 |
+
return inv
|
| 161 |
+
|
| 162 |
+
INVENTORY: List[Dict] = _make_inventory()
|
| 163 |
+
|
| 164 |
+
def search_inventory(
|
| 165 |
+
category: str,
|
| 166 |
+
max_price: float,
|
| 167 |
+
min_price: float = 0,
|
| 168 |
+
keywords: str = "",
|
| 169 |
+
top_k: int = 4,
|
| 170 |
+
avoids: str = "",
|
| 171 |
+
) -> List[Dict]:
|
| 172 |
+
kws = [k.strip().lower() for k in keywords.split(",") if k.strip()]
|
| 173 |
+
avd = [a.strip().lower() for a in avoids.split(",") if a.strip()]
|
| 174 |
+
results = []
|
| 175 |
+
for item in INVENTORY:
|
| 176 |
+
if item["category"] != category:
|
| 177 |
+
continue
|
| 178 |
+
if item["ask_price"] > max_price * 1.15:
|
| 179 |
+
continue
|
| 180 |
+
if item["ask_price"] < min_price:
|
| 181 |
+
continue
|
| 182 |
+
combined = f"{item['name']} {item['features']} {item['notes']}".lower()
|
| 183 |
+
if any(av in combined for av in avd):
|
| 184 |
+
continue
|
| 185 |
+
kw_score = sum(1 for kw in kws if kw in combined)
|
| 186 |
+
mid = ((max_price + min_price) / 2
|
| 187 |
+
if min_price > 0 else max_price * 0.8)
|
| 188 |
+
price_dist = abs(item["ask_price"] - mid) / max(mid, 1)
|
| 189 |
+
score = kw_score * 2 - price_dist
|
| 190 |
+
results.append({**item, "_score": score})
|
| 191 |
+
results.sort(key=lambda x: x["_score"], reverse=True)
|
| 192 |
+
return results[:top_k]
|
| 193 |
+
|
| 194 |
+
def format_inventory_context(
|
| 195 |
+
items: List[Dict], reveal_floor: bool = False
|
| 196 |
+
) -> str:
|
| 197 |
+
if not items:
|
| 198 |
+
return "No matching inventory found."
|
| 199 |
+
lines = []
|
| 200 |
+
for it in items:
|
| 201 |
+
line = (f"[{it['id']}] {it['name']} | {it['condition']} | "
|
| 202 |
+
f"Ask: ${it['ask_price']:,} | Features: {it['features']}")
|
| 203 |
+
if reveal_floor:
|
| 204 |
+
line += f" | Floor: ${it['reservation_price']:,}"
|
| 205 |
+
lines.append(line)
|
| 206 |
+
return "\n".join(lines)
|
| 207 |
+
|
| 208 |
+
# ================================================================
|
| 209 |
+
# TEMPLATES
|
| 210 |
+
# ================================================================
|
| 211 |
+
TEMPLATES = {
|
| 212 |
+
"seller_open_firm": [
|
| 213 |
+
"I've had this {item} listed and I'm firm at ${p:,.0f}. "
|
| 214 |
+
"It's priced fairly for the condition.",
|
| 215 |
+
"The market supports ${p:,.0f} for a {item} like this. "
|
| 216 |
+
"I've done my research.",
|
| 217 |
+
"Asking ${p:,.0f} for the {item}. I'm not in a rush β "
|
| 218 |
+
"prefer not to negotiate far from that.",
|
| 219 |
+
],
|
| 220 |
+
"seller_open_motivated": [
|
| 221 |
+
"I'm listing the {item} at ${p:,.0f} but open to reasonable offers. "
|
| 222 |
+
"I'd like to move this quickly.",
|
| 223 |
+
"Got this {item} up for ${p:,.0f}. "
|
| 224 |
+
"Motivated to sell β make me an offer.",
|
| 225 |
+
"Selling the {item} at ${p:,.0f}. "
|
| 226 |
+
"I have flexibility if you're serious about buying today.",
|
| 227 |
+
],
|
| 228 |
+
"seller_counter_hold": [
|
| 229 |
+
"I appreciate the offer but I can't go below ${p:,.0f}. "
|
| 230 |
+
"That's really my floor.",
|
| 231 |
+
"I hear you, but ${p:,.0f} is already a stretch. "
|
| 232 |
+
"I have other interested buyers closer to asking.",
|
| 233 |
+
"That doesn't quite work. I could come to ${p:,.0f} "
|
| 234 |
+
"but that's genuinely as low as I go.",
|
| 235 |
+
],
|
| 236 |
+
"seller_counter_concede": [
|
| 237 |
+
"Alright, I can meet you a bit closer β how does ${p:,.0f} sound?",
|
| 238 |
+
"I've thought about it and I can work with ${p:,.0f} "
|
| 239 |
+
"if we can close today.",
|
| 240 |
+
"Let me split the difference with you. ${p:,.0f} β fair?",
|
| 241 |
+
],
|
| 242 |
+
"seller_stall": [
|
| 243 |
+
"Let me think on that overnight. "
|
| 244 |
+
"I want to make sure I'm not leaving too much on the table.",
|
| 245 |
+
"I've got another showing tomorrow. "
|
| 246 |
+
"Give me until then to decide if your number works.",
|
| 247 |
+
"I need to check with my partner before I commit to that price.",
|
| 248 |
+
],
|
| 249 |
+
"seller_reject": [
|
| 250 |
+
"I can't do that price β it doesn't cover what I have into this.",
|
| 251 |
+
"That's too far from asking. I'd rather hold onto it.",
|
| 252 |
+
"I appreciate you trying but that number doesn't work for me at all.",
|
| 253 |
+
],
|
| 254 |
+
"seller_return_after_walkaway": [
|
| 255 |
+
"Hey, I've been thinking. The other buyer fell through β "
|
| 256 |
+
"would you still do ${p:,.0f}?",
|
| 257 |
+
"Circling back β other deal didn't pan out. "
|
| 258 |
+
"If ${p:,.0f} is still on the table I'd like to make it work.",
|
| 259 |
+
"The showing yesterday didn't go anywhere. "
|
| 260 |
+
"I'm willing to revisit your ${p:,.0f}.",
|
| 261 |
+
],
|
| 262 |
+
"seller_urgency": [
|
| 263 |
+
"Someone else is coming to look this weekend. "
|
| 264 |
+
"If you want it at ${p:,.0f} I need to know by tomorrow.",
|
| 265 |
+
"Just so you know I've got two other people interested. "
|
| 266 |
+
"First right of refusal at ${p:,.0f}.",
|
| 267 |
+
"My situation has changed and I need to close this week. "
|
| 268 |
+
"${p:,.0f} only if we finalize today.",
|
| 269 |
+
],
|
| 270 |
+
"seller_accept": [
|
| 271 |
+
"You know what, ${p:,.0f} works. Let's do it.",
|
| 272 |
+
"Deal. ${p:,.0f} and it's yours.",
|
| 273 |
+
"Alright, I'll take ${p:,.0f}. When can you pick it up?",
|
| 274 |
+
],
|
| 275 |
+
"seller_exit": [
|
| 276 |
+
"I don't think we're going to get there on price. "
|
| 277 |
+
"Good luck with your search.",
|
| 278 |
+
"We're too far apart. I'm going to wait for a better offer.",
|
| 279 |
+
"I appreciate the interest but this isn't going to work "
|
| 280 |
+
"at your number.",
|
| 281 |
+
],
|
| 282 |
+
"seller_search": [
|
| 283 |
+
"Let me check if I have something that better fits "
|
| 284 |
+
"what you're describing.",
|
| 285 |
+
"Hold on β I think I may have another option in my inventory "
|
| 286 |
+
"that suits your needs.",
|
| 287 |
+
"I want to make sure I'm showing you the best match. "
|
| 288 |
+
"Let me pull some alternatives.",
|
| 289 |
+
],
|
| 290 |
+
"buyer_open_aggressive": [
|
| 291 |
+
"I'll offer ${p:,.0f} and that's already above what I was "
|
| 292 |
+
"planning to spend.",
|
| 293 |
+
"I can do ${p:,.0f} cash today. "
|
| 294 |
+
"I know that's low but I need to stay in my budget.",
|
| 295 |
+
"First and best offer: ${p:,.0f}. "
|
| 296 |
+
"I've seen similar {item}s go for less.",
|
| 297 |
+
],
|
| 298 |
+
"buyer_open_strategic": [
|
| 299 |
+
"I've done some research on {item} values in this market. "
|
| 300 |
+
"Based on comps I think ${p:,.0f} is fair.",
|
| 301 |
+
"I'm genuinely interested. I'd like to start at ${p:,.0f} β "
|
| 302 |
+
"I think there's a deal here.",
|
| 303 |
+
"Serious buyer, ready to close fast. "
|
| 304 |
+
"With that in mind, ${p:,.0f}.",
|
| 305 |
+
],
|
| 306 |
+
"buyer_counter_nibble": [
|
| 307 |
+
"Getting closer. Can you do ${p:,.0f}? "
|
| 308 |
+
"That's where I need to be to feel good about the deal.",
|
| 309 |
+
"I'd say yes at ${p:,.0f}. "
|
| 310 |
+
"Throw in the extras and I'll pull the trigger right now.",
|
| 311 |
+
"If you can get to ${p:,.0f} I won't waste any more of "
|
| 312 |
+
"your time β deal done.",
|
| 313 |
+
],
|
| 314 |
+
"buyer_counter_hold": [
|
| 315 |
+
"I've thought about it and I'm still at ${p:,.0f}. "
|
| 316 |
+
"That's genuinely what this is worth to me.",
|
| 317 |
+
"My budget hasn't changed. ${p:,.0f} is the number.",
|
| 318 |
+
"I hear you on the other buyers but ${p:,.0f} is my ceiling.",
|
| 319 |
+
],
|
| 320 |
+
"buyer_stall": [
|
| 321 |
+
"I need to sleep on it. "
|
| 322 |
+
"I'm also looking at a couple other options this week.",
|
| 323 |
+
"Let me talk to my partner tonight and get back to you tomorrow.",
|
| 324 |
+
"I'm not going to rush into this. Give me a day or two.",
|
| 325 |
+
],
|
| 326 |
+
"buyer_walkaway": [
|
| 327 |
+
"I don't think we're going to get there. "
|
| 328 |
+
"Thanks for your time β good luck with the sale.",
|
| 329 |
+
"I'm going to pass. The price just doesn't work for what I need.",
|
| 330 |
+
"Going to look at other options. "
|
| 331 |
+
"If your price changes, feel free to reach out.",
|
| 332 |
+
],
|
| 333 |
+
"buyer_return_after_walkaway": [
|
| 334 |
+
"Hey, been thinking about the {item} since we talked. "
|
| 335 |
+
"Is ${p:,.0f} still the best you can do?",
|
| 336 |
+
"Still have the {item} available? "
|
| 337 |
+
"I might stretch to ${p:,.0f} if we can close quickly.",
|
| 338 |
+
"Came back because I couldn't find anything comparable. "
|
| 339 |
+
"Would you take ${p:,.0f}?",
|
| 340 |
+
],
|
| 341 |
+
"buyer_accept": [
|
| 342 |
+
"Alright, you've got a deal at ${p:,.0f}.",
|
| 343 |
+
"Fine, ${p:,.0f}. Let's stop going back and forth β I'll take it.",
|
| 344 |
+
"Done. ${p:,.0f}. When can I come get it?",
|
| 345 |
+
],
|
| 346 |
+
"buyer_reject": [
|
| 347 |
+
"That's still too high. I can't justify that price.",
|
| 348 |
+
"No, that doesn't work. "
|
| 349 |
+
"I'd need to see a significant move to reconsider.",
|
| 350 |
+
"I'm out at that number. "
|
| 351 |
+
"Not what the market is bearing right now.",
|
| 352 |
+
],
|
| 353 |
+
"buyer_deadline": [
|
| 354 |
+
"I need to make a decision by end of day β "
|
| 355 |
+
"can you give me your absolute best price?",
|
| 356 |
+
"My budget approval expires Friday. "
|
| 357 |
+
"If we agree on ${p:,.0f} right now I can move immediately.",
|
| 358 |
+
"I have to make a call today. "
|
| 359 |
+
"Meet me at ${p:,.0f} and we close this out.",
|
| 360 |
+
],
|
| 361 |
+
"buyer_search": [
|
| 362 |
+
"Do you have anything else in this category that might "
|
| 363 |
+
"work better for my needs?",
|
| 364 |
+
"I'm not sure this is the right fit. "
|
| 365 |
+
"Do you have other options I should look at?",
|
| 366 |
+
"Before I decide, do you have alternatives β "
|
| 367 |
+
"maybe different condition or price point?",
|
| 368 |
+
],
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
def _t(key: str, item: str = "", p: float = 0,
|
| 372 |
+
avoid: str = "", must: str = "") -> str:
|
| 373 |
+
return random.choice(TEMPLATES[key]).format(
|
| 374 |
+
item=item, p=p, avoid=avoid, must=must
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# ================================================================
|
| 378 |
+
# STRATEGY PROFILES
|
| 379 |
+
# ================================================================
|
| 380 |
+
BUYER_STRATEGY = {
|
| 381 |
+
"aggressive": {
|
| 382 |
+
"open_discount": (0.55, 0.68), "concession_rate": 0.015,
|
| 383 |
+
"walkaway_prob": 0.35, "return_prob": 0.50,
|
| 384 |
+
"patience": 3, "search_prob": 0.10,
|
| 385 |
+
},
|
| 386 |
+
"patient": {
|
| 387 |
+
"open_discount": (0.72, 0.82), "concession_rate": 0.025,
|
| 388 |
+
"walkaway_prob": 0.15, "return_prob": 0.70,
|
| 389 |
+
"patience": 8, "search_prob": 0.20,
|
| 390 |
+
},
|
| 391 |
+
"skeptical": {
|
| 392 |
+
"open_discount": (0.65, 0.75), "concession_rate": 0.018,
|
| 393 |
+
"walkaway_prob": 0.28, "return_prob": 0.45,
|
| 394 |
+
"patience": 5, "search_prob": 0.30,
|
| 395 |
+
},
|
| 396 |
+
"impulsive": {
|
| 397 |
+
"open_discount": (0.78, 0.88), "concession_rate": 0.040,
|
| 398 |
+
"walkaway_prob": 0.10, "return_prob": 0.30,
|
| 399 |
+
"patience": 2, "search_prob": 0.05,
|
| 400 |
+
},
|
| 401 |
+
"strategic": {
|
| 402 |
+
"open_discount": (0.62, 0.72), "concession_rate": 0.022,
|
| 403 |
+
"walkaway_prob": 0.30, "return_prob": 0.65,
|
| 404 |
+
"patience": 7, "search_prob": 0.25,
|
| 405 |
+
},
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
SELLER_STRATEGY = {
|
| 409 |
+
"firm": {
|
| 410 |
+
"min_discount": 0.93, "concession_rate": 0.008,
|
| 411 |
+
"urgency_prob": 0.15, "return_prob": 0.30, "search_prob": 0.15,
|
| 412 |
+
},
|
| 413 |
+
"motivated": {
|
| 414 |
+
"min_discount": 0.82, "concession_rate": 0.030,
|
| 415 |
+
"urgency_prob": 0.40, "return_prob": 0.60, "search_prob": 0.35,
|
| 416 |
+
},
|
| 417 |
+
"anchoring": {
|
| 418 |
+
"min_discount": 0.90, "concession_rate": 0.010,
|
| 419 |
+
"urgency_prob": 0.25, "return_prob": 0.40, "search_prob": 0.20,
|
| 420 |
+
},
|
| 421 |
+
"collaborative": {
|
| 422 |
+
"min_discount": 0.86, "concession_rate": 0.022,
|
| 423 |
+
"urgency_prob": 0.20, "return_prob": 0.55, "search_prob": 0.40,
|
| 424 |
+
},
|
| 425 |
+
"desperate": {
|
| 426 |
+
"min_discount": 0.75, "concession_rate": 0.045,
|
| 427 |
+
"urgency_prob": 0.60, "return_prob": 0.75, "search_prob": 0.30,
|
| 428 |
+
},
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
# ================================================================
|
| 432 |
+
# DATA GENERATOR
|
| 433 |
+
# ================================================================
|
| 434 |
def generate_sessions(n_sessions: int) -> List[Dict]:
|
|
|
|
| 435 |
all_rows = []
|
| 436 |
+
|
| 437 |
+
for _ in range(int(n_sessions)):
|
| 438 |
+
cat = random.choice(CATEGORIES)
|
| 439 |
+
item = cat.replace("_", " ").title()
|
| 440 |
+
lp = round(random.uniform(500, 25000), -1)
|
| 441 |
+
sid = f"SYN-{uuid.uuid4().hex[:6].upper()}"
|
| 442 |
+
b_persona = random.choice(BUYER_PERSONAS)
|
| 443 |
+
s_persona = random.choice(SELLER_PERSONAS)
|
| 444 |
+
bs = BUYER_STRATEGY[b_persona]
|
| 445 |
+
ss = SELLER_STRATEGY[s_persona]
|
| 446 |
+
turn = 0
|
| 447 |
+
rows = []
|
| 448 |
+
walked = False
|
| 449 |
+
|
| 450 |
+
b_budget = lp * random.uniform(0.85, 1.05)
|
| 451 |
+
b_estimate = lp * random.uniform(0.65, 0.80)
|
| 452 |
+
s_reserve = lp * random.uniform(0.72, 0.88)
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
def add(party, price, mtype, msg):
|
| 455 |
nonlocal turn
|
| 456 |
turn += 1
|
| 457 |
+
rows.append({
|
| 458 |
+
"session_id": sid,
|
| 459 |
+
"turn_number": turn,
|
| 460 |
+
"party": party,
|
| 461 |
+
"category": cat,
|
| 462 |
+
"item": item,
|
| 463 |
+
"list_price": lp,
|
| 464 |
+
"offer_price": round(price, 2),
|
| 465 |
+
"msg_type": mtype,
|
| 466 |
+
"message": msg,
|
| 467 |
+
"buyer_persona": b_persona,
|
| 468 |
+
"seller_persona": s_persona,
|
| 469 |
+
"buyer_budget": b_budget,
|
| 470 |
+
"buyer_estimate": b_estimate,
|
| 471 |
+
"seller_reservation": s_reserve,
|
| 472 |
})
|
| 473 |
|
| 474 |
sp = lp
|
| 475 |
+
bp = round(lp * random.uniform(*bs["open_discount"]), -1)
|
| 476 |
+
|
| 477 |
+
s_tmpl = ("seller_open_motivated"
|
| 478 |
+
if s_persona in ["motivated", "desperate"]
|
| 479 |
+
else "seller_open_firm")
|
| 480 |
+
b_tmpl = ("buyer_open_aggressive"
|
| 481 |
+
if b_persona == "aggressive"
|
| 482 |
+
else "buyer_open_strategic")
|
| 483 |
+
|
| 484 |
+
add(0, sp, "offer", _t(s_tmpl, item=item, p=sp))
|
| 485 |
+
add(1, bp, "counter", _t(b_tmpl, item=item, p=bp))
|
| 486 |
+
|
| 487 |
+
max_turns = random.randint(8, 24)
|
| 488 |
+
prev_sp = sp
|
| 489 |
+
prev_bp = bp
|
| 490 |
+
stall_streak = 0
|
| 491 |
+
|
| 492 |
+
for rnd in range(max_turns):
|
| 493 |
+
gap = sp - bp
|
| 494 |
+
gap_pct = gap / lp if lp > 0 else 0
|
| 495 |
+
|
| 496 |
+
# Natural close
|
| 497 |
+
if gap_pct < 0.03:
|
| 498 |
+
fp = round((sp + bp) / 2, -1)
|
| 499 |
+
if random.random() < 0.75:
|
| 500 |
+
add(random.choice([0, 1]), fp, "accept",
|
| 501 |
+
_t("seller_accept"
|
| 502 |
+
if random.random() < 0.5
|
| 503 |
+
else "buyer_accept", p=fp))
|
| 504 |
+
break
|
| 505 |
+
|
| 506 |
+
# ββ Seller turn βββββββββββββββββββββββββββββββββββ
|
| 507 |
+
if random.random() < ss["search_prob"] and rnd > 1:
|
| 508 |
+
add(0, sp, "search", _t("seller_search"))
|
| 509 |
+
match_p = round(sp * random.uniform(0.88, 0.98), -1)
|
| 510 |
+
add(0, match_p, "counter",
|
| 511 |
+
f"I found something that might work better β "
|
| 512 |
+
f"similar {item} at ${match_p:,.0f} with better "
|
| 513 |
+
f"specs for your needs.")
|
| 514 |
+
sp = match_p
|
| 515 |
+
stall_streak = 0
|
| 516 |
+
elif random.random() < ss["urgency_prob"] and rnd > 1:
|
| 517 |
+
add(0, sp, "stall", _t("seller_urgency", item=item, p=sp))
|
| 518 |
+
stall_streak += 1
|
| 519 |
+
elif gap_pct > 0.30:
|
| 520 |
+
add(0, sp, "reject", _t("seller_reject"))
|
| 521 |
+
elif prev_sp == sp and stall_streak < 2:
|
| 522 |
+
add(0, sp, "stall", _t("seller_stall"))
|
| 523 |
+
stall_streak += 1
|
| 524 |
+
else:
|
| 525 |
+
concede_s = (ss["concession_rate"] * lp
|
| 526 |
+
* random.uniform(0.5, 1.5))
|
| 527 |
+
sp = max(max(bp + gap * 0.15, sp - concede_s), s_reserve)
|
| 528 |
+
sp = round(sp, -1)
|
| 529 |
+
tmpl = ("seller_counter_concede"
|
| 530 |
+
if concede_s > lp * 0.02
|
| 531 |
+
else "seller_counter_hold")
|
| 532 |
+
add(0, sp, "counter", _t(tmpl, p=sp))
|
| 533 |
+
stall_streak = 0
|
| 534 |
+
|
| 535 |
+
prev_sp = sp
|
| 536 |
+
gap = sp - bp
|
| 537 |
+
|
| 538 |
+
# ββ Buyer turn ββββββββββββββββββββββββββββββββββββ
|
| 539 |
+
concede_b = (bs["concession_rate"] * lp
|
| 540 |
+
* random.uniform(0.5, 1.5))
|
| 541 |
+
|
| 542 |
+
if (random.random() < bs["search_prob"]
|
| 543 |
+
and gap_pct > 0.12 and rnd > 1):
|
| 544 |
+
add(1, bp, "search", _t("buyer_search"))
|
| 545 |
+
new_bp = round(bp * random.uniform(1.01, 1.06), -1)
|
| 546 |
+
add(1, new_bp, "counter",
|
| 547 |
+
f"I looked at your alternatives β I could do "
|
| 548 |
+
f"${new_bp:,.0f} for the right {item} with the "
|
| 549 |
+
f"features I need.")
|
| 550 |
+
bp = new_bp
|
| 551 |
+
|
| 552 |
+
elif (not walked
|
| 553 |
+
and random.random() < bs["walkaway_prob"]
|
| 554 |
+
and rnd > 2):
|
| 555 |
+
walked = True
|
| 556 |
+
add(1, bp, "exit", _t("buyer_walkaway"))
|
| 557 |
+
if random.random() < bs["return_prob"]:
|
| 558 |
+
rp = round(bp * 1.04, -1)
|
| 559 |
+
add(1, rp, "counter",
|
| 560 |
+
_t("buyer_return_after_walkaway",
|
| 561 |
+
item=item, p=rp))
|
| 562 |
+
bp = rp
|
| 563 |
+
else:
|
| 564 |
+
break
|
| 565 |
+
|
| 566 |
+
elif rnd > bs["patience"] and random.random() < 0.30:
|
| 567 |
+
bp = min(sp - gap * 0.1, bp + concede_b)
|
| 568 |
+
bp = min(bp, b_budget)
|
| 569 |
+
bp = round(bp, -1)
|
| 570 |
+
add(1, bp, "counter", _t("buyer_deadline", p=bp))
|
| 571 |
+
|
| 572 |
+
elif gap_pct < 0.08 and random.random() < 0.40:
|
| 573 |
+
add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))
|
| 574 |
+
|
| 575 |
+
elif random.random() < 0.15:
|
| 576 |
+
add(1, bp, "stall", _t("buyer_stall"))
|
| 577 |
+
|
| 578 |
+
elif prev_bp == bp and random.random() < 0.35:
|
| 579 |
+
add(1, bp, "counter", _t("buyer_counter_hold", p=bp))
|
| 580 |
+
|
| 581 |
+
else:
|
| 582 |
+
bp = min(bp + concede_b, b_budget)
|
| 583 |
+
bp = min(sp - gap * 0.15, bp)
|
| 584 |
+
bp = round(bp, -1)
|
| 585 |
+
add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))
|
| 586 |
+
|
| 587 |
+
prev_bp = bp
|
| 588 |
+
|
| 589 |
+
if gap / lp > 0.45:
|
| 590 |
+
add(1, bp, "exit", _t("buyer_reject"))
|
| 591 |
+
if random.random() < ss["return_prob"]:
|
| 592 |
+
new_sp = round(sp * 0.94, -1)
|
| 593 |
+
add(0, new_sp, "counter",
|
| 594 |
+
_t("seller_return_after_walkaway", p=new_sp))
|
| 595 |
+
sp = new_sp
|
| 596 |
+
else:
|
| 597 |
+
break
|
| 598 |
else:
|
| 599 |
+
if (sp - bp) / lp < 0.08:
|
| 600 |
+
fp = round((sp + bp) / 2, -1)
|
| 601 |
+
add(random.choice([0, 1]), fp, "accept",
|
| 602 |
+
_t("seller_accept", p=fp))
|
| 603 |
+
else:
|
| 604 |
+
add(1, bp, "exit", _t("buyer_walkaway"))
|
| 605 |
+
|
| 606 |
+
all_rows.extend(rows)
|
| 607 |
+
|
| 608 |
return all_rows
|
| 609 |
|
| 610 |
+
# ================================================================
|
| 611 |
+
# FEATURE EXTRACTION β all list guards in place
|
| 612 |
+
# ================================================================
|
| 613 |
+
def extract_features(turns, idx, lp,
|
| 614 |
+
b_budget=0, b_estimate=0, s_reserve=0):
|
| 615 |
+
hist = turns[:idx]
|
| 616 |
+
if len(hist) < 1:
|
| 617 |
+
return [0.0] * 10
|
| 618 |
+
|
| 619 |
+
sp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 0]
|
| 620 |
+
bp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 1]
|
| 621 |
+
|
| 622 |
+
s_vel = ((sp_prices[-1] - sp_prices[0]) / lp) \
|
| 623 |
+
if len(sp_prices) > 1 else 0.0
|
| 624 |
+
b_vel = ((bp_prices[-1] - bp_prices[0]) / lp) \
|
| 625 |
+
if len(bp_prices) > 1 else 0.0
|
| 626 |
+
|
| 627 |
+
gap_r = ((sp_prices[-1] - bp_prices[-1]) / lp) \
|
| 628 |
+
if (sp_prices and bp_prices) else 1.0
|
| 629 |
+
|
| 630 |
+
s_con = sum(
|
| 631 |
+
max(0, sp_prices[i-1] - sp_prices[i])
|
| 632 |
+
for i in range(1, len(sp_prices))
|
| 633 |
+
) / lp if len(sp_prices) > 1 else 0.0
|
| 634 |
+
|
| 635 |
+
b_con = sum(
|
| 636 |
+
max(0, bp_prices[i] - bp_prices[i-1])
|
| 637 |
+
for i in range(1, len(bp_prices))
|
| 638 |
+
) / lp if len(bp_prices) > 1 else 0.0
|
| 639 |
+
|
| 640 |
+
stalls = (sum(1 for r in hist if r["msg_type"] == "stall")
|
| 641 |
+
/ max(len(hist), 1))
|
| 642 |
+
searches = (sum(1 for r in hist if r["msg_type"] == "search")
|
| 643 |
+
/ max(len(hist), 1))
|
| 644 |
+
|
| 645 |
+
# Bound-relative β guarded against empty lists
|
| 646 |
+
budget_dist = min(
|
| 647 |
+
(bp_prices[-1] - b_estimate) / max(b_budget - b_estimate, 1), 2.0
|
| 648 |
+
) if (b_budget > 0 and bp_prices) else 0.0
|
| 649 |
+
|
| 650 |
+
floor_dist = min(
|
| 651 |
+
(sp_prices[-1] - s_reserve) / max(lp - s_reserve, 1), 1.5
|
| 652 |
+
) if (s_reserve > 0 and sp_prices) else 0.5
|
| 653 |
+
|
| 654 |
+
turns_norm = min(idx / 25.0, 1.0)
|
| 655 |
+
|
| 656 |
+
return [
|
| 657 |
+
float(s_vel - b_vel),
|
| 658 |
+
float(min(max(gap_r, 0.0), 2.0)),
|
| 659 |
+
float(min(s_con, 2.0)),
|
| 660 |
+
float(min(b_con, 2.0)),
|
| 661 |
+
float(stalls),
|
| 662 |
+
float(searches),
|
| 663 |
+
float(budget_dist),
|
| 664 |
+
float(floor_dist),
|
| 665 |
+
float(turns_norm),
|
| 666 |
+
0.0,
|
| 667 |
+
]
|
| 668 |
+
|
| 669 |
+
# ================================================================
|
| 670 |
+
# DATASET BUILDER β selective pin_memory (small tensors only)
|
| 671 |
+
# ================================================================
|
| 672 |
+
def build_pinned_dataset(rows: List[Dict]) -> TensorDataset:
|
| 673 |
+
sessions = {}
|
| 674 |
+
for r in rows:
|
| 675 |
+
sessions.setdefault(r["session_id"], []).append(r)
|
| 676 |
+
|
| 677 |
+
(texts, party_l, cat_l, ofn_l, tn_l,
|
| 678 |
+
msg_l, pt_l, bp_l, sp_l, mom_l) = ([] for _ in range(10))
|
| 679 |
+
|
| 680 |
+
for turns in sessions.values():
|
| 681 |
+
turns = sorted(turns, key=lambda x: int(x["turn_number"]))
|
| 682 |
+
lp = float(turns[0]["list_price"])
|
| 683 |
+
if lp <= 0:
|
| 684 |
+
continue
|
| 685 |
+
b_bud = float(turns[0].get("buyer_budget", lp))
|
| 686 |
+
b_est = float(turns[0].get("buyer_estimate", lp * 0.75))
|
| 687 |
+
s_res = float(turns[0].get("seller_reservation", lp * 0.80))
|
| 688 |
+
|
| 689 |
+
for i in range(1, len(turns)):
|
| 690 |
+
tgt = turns[i]
|
| 691 |
+
recent = turns[max(0, i-3):i]
|
| 692 |
+
text = " [SEP] ".join(
|
| 693 |
+
f"{'S' if int(t['party'])==0 else 'B'}: {t['message']}"
|
| 694 |
+
for t in recent
|
| 695 |
+
)
|
| 696 |
+
mom = extract_features(turns, i, lp, b_bud, b_est, s_res)
|
| 697 |
+
|
| 698 |
+
texts.append(text)
|
| 699 |
+
party_l.append(int(tgt["party"]))
|
| 700 |
+
cat_l.append(CAT2IDX.get(tgt["category"], 0))
|
| 701 |
+
ofn_l.append(min(float(tgt["offer_price"]) / lp, 3.0))
|
| 702 |
+
tn_l.append(min(int(tgt["turn_number"]) / 25.0, 1.0))
|
| 703 |
+
msg_l.append(MSG2IDX.get(tgt["msg_type"], 1))
|
| 704 |
+
pt_l.append(min(float(tgt["offer_price"]) / lp, 3.0))
|
| 705 |
+
bp_l.append(BPERSONA2IDX.get(
|
| 706 |
+
tgt.get("buyer_persona", "patient"), 1))
|
| 707 |
+
sp_l.append(SPERSONA2IDX.get(
|
| 708 |
+
tgt.get("seller_persona", "firm"), 0))
|
| 709 |
+
mom_l.append(mom)
|
| 710 |
+
|
| 711 |
+
del sessions, rows
|
| 712 |
+
gc.collect()
|
| 713 |
+
|
| 714 |
+
n = len(texts)
|
| 715 |
+
input_ids = torch.empty((n, MAX_LEN), dtype=torch.long)
|
| 716 |
+
attn_mask = torch.empty((n, MAX_LEN), dtype=torch.long)
|
| 717 |
+
|
| 718 |
+
for i in range(0, n, 20000):
|
| 719 |
+
chunk = texts[i : i + 20000]
|
| 720 |
+
enc = tokenizer(
|
| 721 |
+
chunk, max_length=MAX_LEN,
|
| 722 |
+
padding="max_length", truncation=True,
|
| 723 |
+
return_tensors="pt"
|
| 724 |
+
)
|
| 725 |
+
input_ids[i : i + 20000] = enc["input_ids"]
|
| 726 |
+
attn_mask[i : i + 20000] = enc["attention_mask"]
|
| 727 |
+
|
| 728 |
+
del texts
|
| 729 |
+
gc.collect()
|
| 730 |
+
|
| 731 |
+
tensors = dict(
|
| 732 |
+
ids = input_ids,
|
| 733 |
+
mask = attn_mask,
|
| 734 |
+
pty = torch.tensor(party_l, dtype=torch.long),
|
| 735 |
+
cat = torch.tensor(cat_l, dtype=torch.long),
|
| 736 |
+
ofn = torch.tensor(ofn_l, dtype=torch.float),
|
| 737 |
+
tn = torch.tensor(tn_l, dtype=torch.float),
|
| 738 |
+
mt = torch.tensor(msg_l, dtype=torch.long),
|
| 739 |
+
pt = torch.tensor(pt_l, dtype=torch.float),
|
| 740 |
+
bp = torch.tensor(bp_l, dtype=torch.long),
|
| 741 |
+
sp = torch.tensor(sp_l, dtype=torch.long),
|
| 742 |
+
mom = torch.tensor(mom_l, dtype=torch.float),
|
| 743 |
+
)
|
| 744 |
+
del party_l, cat_l, ofn_l, tn_l, msg_l, pt_l, bp_l, sp_l, mom_l
|
| 745 |
+
gc.collect()
|
| 746 |
+
|
| 747 |
+
# ββ Selective pin_memory ββββββββββββββββββββββββββββββββββ
|
| 748 |
+
# ids + mask are ~400 MB each β pinning them causes the CUDA
|
| 749 |
+
# driver to reserve matching GPU-side DMA staging buffers,
|
| 750 |
+
# blowing VRAM before training even starts.
|
| 751 |
+
# Only pin the small scalar tensors; they transfer instantly
|
| 752 |
+
# and get the DMA benefit without the memory cost.
|
| 753 |
+
if DEVICE.type == "cuda":
|
| 754 |
+
SMALL_KEYS = {"pty","cat","ofn","tn","mt","pt","bp","sp","mom"}
|
| 755 |
+
tensors = {
|
| 756 |
+
k: (v.pin_memory() if k in SMALL_KEYS else v)
|
| 757 |
+
for k, v in tensors.items()
|
| 758 |
}
|
| 759 |
|
| 760 |
+
return TensorDataset(*tensors.values())
|
| 761 |
+
|
| 762 |
+
# ================================================================
|
| 763 |
+
# MODEL
|
| 764 |
+
# ================================================================
|
| 765 |
class PositionalEncoding(nn.Module):
|
| 766 |
def __init__(self, d: int, max_len: int = 512):
|
| 767 |
super().__init__()
|
| 768 |
self.drop = nn.Dropout(0.1)
|
| 769 |
+
pe = torch.zeros(max_len, d)
|
| 770 |
pos = torch.arange(max_len).unsqueeze(1).float()
|
| 771 |
+
div = torch.exp(
|
| 772 |
+
torch.arange(0, d, 2).float() * (-math.log(10000.0) / d)
|
| 773 |
+
)
|
| 774 |
pe[:, 0::2] = torch.sin(pos * div)
|
| 775 |
pe[:, 1::2] = torch.cos(pos * div)
|
| 776 |
self.register_buffer("pe", pe.unsqueeze(0))
|
| 777 |
|
| 778 |
+
def forward(self, x):
|
| 779 |
+
return self.drop(x + self.pe[:, :x.size(1)])
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
class MomentumEncoder(nn.Module):
|
| 783 |
+
def __init__(self, in_dim: int = 10, out_dim: int = 48):
|
| 784 |
+
super().__init__()
|
| 785 |
+
self.net = nn.Sequential(
|
| 786 |
+
nn.Linear(in_dim, 64), nn.GELU(),
|
| 787 |
+
nn.Linear(64, out_dim)
|
| 788 |
+
)
|
| 789 |
+
def forward(self, x): return self.net(x)
|
| 790 |
+
|
| 791 |
|
| 792 |
class NegotiationTransformer(nn.Module):
|
| 793 |
def __init__(self):
|
| 794 |
super().__init__()
|
| 795 |
+
self.emb = nn.Embedding(30522, D_MODEL, padding_idx=0)
|
| 796 |
+
self.pos = PositionalEncoding(D_MODEL)
|
| 797 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 798 |
+
D_MODEL, N_HEADS, FFN_DIM,
|
| 799 |
+
dropout=0.1, batch_first=True, norm_first=True
|
| 800 |
+
)
|
| 801 |
+
self.encoder = nn.TransformerEncoder(enc_layer, N_LAYERS)
|
| 802 |
+
self.p_emb = nn.Embedding(2, 32)
|
| 803 |
+
self.c_emb = nn.Embedding(len(CATEGORIES), 64)
|
| 804 |
+
self.bp_emb = nn.Embedding(len(BUYER_PERSONAS), 32)
|
| 805 |
+
self.sp_emb = nn.Embedding(len(SELLER_PERSONAS), 32)
|
| 806 |
+
self.mom_enc = MomentumEncoder(10, 48)
|
| 807 |
+
total_ctx = D_MODEL + 32 + 64 + 32 + 32 + 48 + 2
|
| 808 |
+
self.fusion = nn.Sequential(
|
| 809 |
+
nn.Linear(total_ctx, D_MODEL), nn.GELU(), nn.Dropout(0.1)
|
| 810 |
+
)
|
| 811 |
self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
|
| 812 |
+
self.px_head = nn.Sequential(
|
| 813 |
+
nn.Linear(D_MODEL, 128), nn.GELU(),
|
| 814 |
+
nn.Linear(128, 1), nn.Softplus()
|
| 815 |
+
)
|
| 816 |
|
| 817 |
+
def forward(self, ids, mask, party, cat, ofn, tn, bp, sp, mom):
|
| 818 |
+
x = self.pos(self.emb(ids))
|
| 819 |
+
x = self.encoder(x, src_key_padding_mask=(mask == 0))
|
| 820 |
cls = x[:, 0]
|
| 821 |
+
ctx = torch.cat([
|
| 822 |
+
cls,
|
| 823 |
+
self.p_emb(party),
|
| 824 |
+
self.c_emb(cat),
|
| 825 |
+
self.bp_emb(bp),
|
| 826 |
+
self.sp_emb(sp),
|
| 827 |
+
self.mom_enc(mom),
|
| 828 |
+
torch.stack([ofn, tn], dim=1),
|
| 829 |
+
], dim=1)
|
| 830 |
+
f = self.fusion(ctx)
|
| 831 |
return self.msg_head(f), self.px_head(f).squeeze(1)
|
| 832 |
|
| 833 |
+
|
| 834 |
+
class AsymmetricNegotiationLoss(nn.Module):
|
| 835 |
+
def __init__(self):
|
| 836 |
+
super().__init__()
|
| 837 |
+
# [offer, counter, accept, reject, exit, stall, search]
|
| 838 |
+
self.seller_w = torch.tensor([1.0,1.0,1.5,1.2,1.3,0.8,1.1])
|
| 839 |
+
self.buyer_w = torch.tensor([1.0,1.0,1.3,1.0,1.2,0.9,1.2])
|
| 840 |
+
|
| 841 |
+
def forward(self, mt_logits, mt_targets, px_pred, px_targets, party):
|
| 842 |
+
dev = mt_logits.device
|
| 843 |
+
sw = self.seller_w.to(dev)
|
| 844 |
+
bw = self.buyer_w.to(dev)
|
| 845 |
+
loss_mt = torch.zeros(mt_logits.size(0), device=dev)
|
| 846 |
+
sm = (party == 0)
|
| 847 |
+
bm = (party == 1)
|
| 848 |
+
if sm.any():
|
| 849 |
+
loss_mt[sm] = F.cross_entropy(
|
| 850 |
+
mt_logits[sm], mt_targets[sm],
|
| 851 |
+
weight=sw, reduction="none"
|
| 852 |
+
)
|
| 853 |
+
if bm.any():
|
| 854 |
+
loss_mt[bm] = F.cross_entropy(
|
| 855 |
+
mt_logits[bm], mt_targets[bm],
|
| 856 |
+
weight=bw, reduction="none"
|
| 857 |
+
)
|
| 858 |
+
return loss_mt.mean() + 0.5 * F.mse_loss(px_pred, px_targets)
|
| 859 |
+
|
| 860 |
+
# ================================================================
|
| 861 |
+
# PLOT
|
| 862 |
+
# ================================================================
|
| 863 |
+
def plot_curve(losses):
|
| 864 |
+
fig, ax = plt.subplots(figsize=(6, 3))
|
| 865 |
+
if losses:
|
| 866 |
+
ax.plot(range(1, len(losses)+1), losses, "b-o", markersize=4)
|
| 867 |
+
ax.set_title("Training Loss")
|
| 868 |
+
else:
|
| 869 |
+
ax.text(0.5, 0.5, "No data yet",
|
| 870 |
+
ha="center", va="center", alpha=0.5)
|
| 871 |
+
ax.grid(alpha=0.3)
|
| 872 |
+
plt.tight_layout()
|
| 873 |
+
return fig
|
| 874 |
+
|
| 875 |
+
# ================================================================
|
| 876 |
+
# TRAINING
|
| 877 |
+
# ================================================================
|
| 878 |
+
def run_training(n_sessions, epochs, batch_size, lr):
|
| 879 |
global GLOBAL_MODEL
|
| 880 |
+
logs = []
|
| 881 |
+
|
| 882 |
+
def log(msg):
|
| 883 |
+
ts = time.strftime("%H:%M:%S")
|
| 884 |
+
line = f"[{ts}] {msg}"
|
| 885 |
+
logs.append(line)
|
| 886 |
+
if len(logs) > 20:
|
| 887 |
+
logs.pop(0)
|
| 888 |
+
print(line)
|
| 889 |
+
return "\n".join(logs)
|
| 890 |
+
|
| 891 |
try:
|
| 892 |
+
batch_size = int(batch_size)
|
| 893 |
+
log_txt = log(f"Generating {int(n_sessions):,} sessions...")
|
| 894 |
+
yield "π‘ Generating...", log_txt, plot_curve([]), "β Needs Training"
|
| 895 |
+
|
| 896 |
+
rows = generate_sessions(int(n_sessions))
|
| 897 |
+
log_txt = log(f"Generated {len(rows):,} rows. Building dataset...")
|
| 898 |
+
yield "π‘ Tokenizing...", log_txt, plot_curve([]), "β Needs Training"
|
| 899 |
+
|
| 900 |
+
dataset = build_pinned_dataset(rows)
|
| 901 |
+
loader = DataLoader(
|
| 902 |
+
dataset, batch_size=batch_size,
|
| 903 |
+
shuffle=True, num_workers=0,
|
| 904 |
+
pin_memory=False, drop_last=True
|
| 905 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 906 |
total_batches = len(loader)
|
| 907 |
+
|
| 908 |
+
log_txt = log(f"Dataset: {len(dataset):,} samples | "
|
| 909 |
+
f"{total_batches} batches | bs={batch_size}")
|
| 910 |
+
yield "π‘ Building model...", log_txt, plot_curve([]), "β Needs Training"
|
| 911 |
+
|
| 912 |
+
model = NegotiationTransformer().to(DEVICE)
|
| 913 |
+
crit = AsymmetricNegotiationLoss()
|
| 914 |
+
|
| 915 |
+
if hasattr(torch, "compile") and DEVICE.type == "cuda":
|
| 916 |
+
try:
|
| 917 |
+
model = torch.compile(model, backend="cudagraphs")
|
| 918 |
+
log_txt = log("torch.compile (cudagraphs) applied")
|
| 919 |
+
except Exception as ce:
|
| 920 |
+
log_txt = log(f"compile skipped: {ce}")
|
| 921 |
+
|
| 922 |
+
opt = AdamW(model.parameters(), lr=float(lr), weight_decay=1e-2)
|
| 923 |
+
sch = CosineAnnealingLR(opt, T_max=int(epochs))
|
| 924 |
+
scaler = torch.cuda.amp.GradScaler()
|
| 925 |
+
losses = []
|
| 926 |
+
|
| 927 |
+
log_txt = log("π Training started")
|
| 928 |
+
yield "π’ Training...", log_txt, plot_curve([]), "β Needs Training"
|
| 929 |
+
|
| 930 |
+
for ep in range(int(epochs)):
|
| 931 |
model.train()
|
| 932 |
ep_loss = 0.0
|
| 933 |
+
t0 = time.time()
|
| 934 |
+
|
|
|
|
| 935 |
for i, batch in enumerate(loader):
|
| 936 |
+
(b_ids, b_mask, b_pty, b_cat, b_ofn,
|
| 937 |
+
b_tn, b_mt, b_pt, b_bp, b_sp, b_mom) = [
|
| 938 |
+
t.to(DEVICE, non_blocking=True) for t in batch
|
| 939 |
+
]
|
| 940 |
+
|
| 941 |
+
if i % 100 == 0:
|
| 942 |
+
el = time.time() - t0
|
| 943 |
+
ms_b = (el / max(i, 1)) * 1000
|
| 944 |
+
status = (f"π’ Epoch {ep+1}/{int(epochs)} | "
|
| 945 |
+
f"Batch {i}/{total_batches} | "
|
| 946 |
+
f"{ms_b:.0f}ms/batch")
|
| 947 |
+
log_txt = log(status)
|
| 948 |
+
yield (status, log_txt,
|
| 949 |
+
plot_curve(losses), "β Needs Training")
|
| 950 |
+
|
| 951 |
+
opt.zero_grad(set_to_none=True)
|
| 952 |
+
|
| 953 |
+
with torch.cuda.amp.autocast():
|
| 954 |
+
mt_logits, px_pred = model(
|
| 955 |
+
b_ids, b_mask, b_pty, b_cat,
|
| 956 |
+
b_ofn, b_tn, b_bp, b_sp, b_mom
|
| 957 |
+
)
|
| 958 |
+
loss = crit(mt_logits, b_mt, px_pred, b_pt, b_pty)
|
| 959 |
+
|
| 960 |
+
scaler.scale(loss).backward()
|
| 961 |
+
scaler.unscale_(opt)
|
| 962 |
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 963 |
+
scaler.step(opt)
|
| 964 |
+
scaler.update()
|
| 965 |
ep_loss += loss.item()
|
| 966 |
+
|
| 967 |
sch.step()
|
| 968 |
+
avg = ep_loss / max(total_batches, 1)
|
| 969 |
+
et = time.time() - t0
|
| 970 |
+
losses.append(avg)
|
| 971 |
+
log_txt = log(
|
| 972 |
+
f"Epoch {ep+1}/{int(epochs)} done β "
|
| 973 |
+
f"loss: {avg:.4f} | {et:.1f}s | "
|
| 974 |
+
f"{et/total_batches*1000:.0f}ms/batch"
|
| 975 |
+
)
|
| 976 |
+
yield (f"π’ Epoch {ep+1} done", log_txt,
|
| 977 |
+
plot_curve(losses), "β Needs Training")
|
| 978 |
|
|
|
|
| 979 |
model.eval()
|
| 980 |
GLOBAL_MODEL = model
|
| 981 |
+
log_txt = log("β
Training complete.")
|
| 982 |
+
yield "π΅ Complete", log_txt, plot_curve(losses), "β
Ready"
|
| 983 |
+
|
| 984 |
except Exception as e:
|
| 985 |
+
import traceback
|
| 986 |
+
log_txt = log(f"ERROR: {e}\n{traceback.format_exc()}")
|
| 987 |
+
yield "π΄ ERROR", log_txt, plot_curve([]), "β Needs Training"
|
| 988 |
+
|
| 989 |
+
# ================================================================
|
| 990 |
+
# INFERENCE ENGINE
|
| 991 |
+
# ================================================================
|
| 992 |
+
def _build_message(msg_type, price, item,
|
| 993 |
+
is_buyer, persona, inv_context=""):
|
| 994 |
+
p = price
|
| 995 |
+
if msg_type == "search":
|
| 996 |
+
return _t("buyer_search") if is_buyer else _t("seller_search")
|
| 997 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
if is_buyer:
|
| 999 |
+
m = {
|
| 1000 |
+
"offer": _t("buyer_open_strategic", item=item, p=p),
|
| 1001 |
+
"counter": _t("buyer_counter_nibble", p=p),
|
| 1002 |
+
"accept": _t("buyer_accept", p=p),
|
| 1003 |
+
"reject": _t("buyer_reject"),
|
| 1004 |
+
"exit": _t("buyer_walkaway"),
|
| 1005 |
+
"stall": _t("buyer_stall"),
|
| 1006 |
}
|
| 1007 |
+
if persona == "aggressive":
|
| 1008 |
+
m["offer"] = _t("buyer_open_aggressive", item=item, p=p)
|
| 1009 |
+
m["counter"] = _t("buyer_counter_hold", p=p)
|
| 1010 |
else:
|
| 1011 |
+
m = {
|
| 1012 |
+
"offer": _t("seller_open_firm", item=item, p=p),
|
| 1013 |
+
"counter": _t("seller_counter_hold", p=p),
|
| 1014 |
+
"accept": _t("seller_accept", p=p),
|
| 1015 |
+
"reject": _t("seller_reject"),
|
| 1016 |
+
"exit": _t("seller_exit"),
|
| 1017 |
+
"stall": _t("seller_stall"),
|
| 1018 |
}
|
| 1019 |
+
if persona in ["motivated", "desperate"]:
|
| 1020 |
+
m["offer"] = _t("seller_open_motivated", item=item, p=p)
|
| 1021 |
+
m["counter"] = _t("seller_counter_concede", p=p)
|
| 1022 |
+
if inv_context:
|
| 1023 |
+
m["search"] = (
|
| 1024 |
+
"Let me check my inventory...\n"
|
| 1025 |
+
f"{inv_context}\nWould any of these work for you?"
|
| 1026 |
+
)
|
| 1027 |
+
return m.get(msg_type, f"{msg_type} @ ${p:,.2f}")
|
| 1028 |
+
|
| 1029 |
|
| 1030 |
+
def run_inference_turn(
|
| 1031 |
+
session_state,
|
| 1032 |
+
category, item,
|
| 1033 |
+
list_price, user_price, user_message,
|
| 1034 |
+
user_party, user_persona, ai_persona,
|
| 1035 |
+
buyer_budget, buyer_estimate,
|
| 1036 |
+
buyer_avoids, buyer_must_have,
|
| 1037 |
+
seller_reservation, seller_urgency,
|
| 1038 |
+
):
|
| 1039 |
if GLOBAL_MODEL is None:
|
| 1040 |
+
return (session_state,
|
| 1041 |
+
session_state.get("history_ui", []),
|
| 1042 |
+
"Model not trained.", "", "", "", "")
|
| 1043 |
+
|
| 1044 |
+
lp = float(list_price)
|
| 1045 |
+
is_user_buyer = (user_party == "Buyer")
|
| 1046 |
+
ai_party_int = 0 if is_user_buyer else 1
|
| 1047 |
+
|
| 1048 |
+
# ββ Initialise session ββββββββββββββββββββββββββββββββββββ
|
| 1049 |
+
if not session_state.get("started"):
|
| 1050 |
+
init_bp = (float(buyer_estimate)
|
| 1051 |
+
if float(buyer_estimate) > 0
|
| 1052 |
+
else round(lp * 0.75, -1))
|
| 1053 |
+
session_state = {
|
| 1054 |
+
"started": True,
|
| 1055 |
+
"turn": 0,
|
| 1056 |
+
"sp": lp,
|
| 1057 |
+
"bp": init_bp,
|
| 1058 |
+
"history": [],
|
| 1059 |
+
"history_ui": [],
|
| 1060 |
+
"status": "active",
|
| 1061 |
+
"inv_context": "",
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
if session_state["status"] != "active":
|
| 1065 |
+
return (session_state, session_state["history_ui"],
|
| 1066 |
+
"Session ended β click New Session to restart.",
|
| 1067 |
+
"", "", "", "")
|
| 1068 |
+
|
| 1069 |
+
history = session_state["history"]
|
| 1070 |
+
history_ui = session_state["history_ui"]
|
| 1071 |
+
sp = float(session_state["sp"])
|
| 1072 |
+
bp = float(session_state["bp"])
|
| 1073 |
+
turn = session_state["turn"]
|
| 1074 |
+
|
| 1075 |
+
b_bud = float(buyer_budget) if float(buyer_budget) > 0 else lp
|
| 1076 |
+
b_est = float(buyer_estimate) if float(buyer_estimate) > 0 else lp * 0.75
|
| 1077 |
+
s_res = float(seller_reservation) if float(seller_reservation) > 0 else lp * 0.80
|
| 1078 |
+
|
| 1079 |
+
# ββ Record user turn ββββββββββββββββββββββββββββββββββββββ
|
| 1080 |
+
u_int = 1 if is_user_buyer else 0
|
| 1081 |
+
history.append({
|
| 1082 |
+
"party": u_int,
|
| 1083 |
+
"message": user_message,
|
| 1084 |
+
"offer_price": float(user_price),
|
| 1085 |
+
"msg_type": "counter",
|
| 1086 |
+
"turn_number": turn + 1,
|
| 1087 |
+
})
|
| 1088 |
+
history_ui.append((
|
| 1089 |
+
f"{'π§ You (Buyer)' if is_user_buyer else 'π§ You (Seller)'}"
|
| 1090 |
+
f" [${float(user_price):,.0f}]: {user_message}",
|
| 1091 |
+
None
|
| 1092 |
+
))
|
| 1093 |
+
turn += 1
|
| 1094 |
+
|
| 1095 |
+
if is_user_buyer:
|
| 1096 |
+
bp = float(user_price)
|
| 1097 |
else:
|
| 1098 |
+
sp = float(user_price)
|
| 1099 |
+
|
| 1100 |
+
# ββ Build momentum features βββββββββββββββββββββββββββββββ
|
| 1101 |
+
sp_prices = [r["offer_price"] for r in history if int(r["party"]) == 0]
|
| 1102 |
+
bp_prices = [r["offer_price"] for r in history if int(r["party"]) == 1]
|
| 1103 |
+
|
| 1104 |
+
s_vel = ((sp_prices[-1]-sp_prices[0])/lp) if len(sp_prices)>1 else 0.0
|
| 1105 |
+
b_vel = ((bp_prices[-1]-bp_prices[0])/lp) if len(bp_prices)>1 else 0.0
|
| 1106 |
+
gap_r = ((sp - bp) / lp) if sp > bp else 0.0
|
| 1107 |
+
stalls = (sum(1 for r in history if r["msg_type"] == "stall")
|
| 1108 |
+
/ max(len(history), 1))
|
| 1109 |
+
srch = (sum(1 for r in history if r["msg_type"] == "search")
|
| 1110 |
+
/ max(len(history), 1))
|
| 1111 |
+
b_dist = min((bp - b_est) / max(b_bud - b_est, 1), 2.0) \
|
| 1112 |
+
if (b_bud > 0 and bp_prices) else 0.0
|
| 1113 |
+
f_dist = min((sp - s_res) / max(lp - s_res, 1), 1.5) \
|
| 1114 |
+
if (s_res > 0 and sp_prices) else 0.5
|
| 1115 |
+
|
| 1116 |
+
mom = [
|
| 1117 |
+
float(s_vel - b_vel),
|
| 1118 |
+
float(min(max(gap_r, 0.0), 2.0)),
|
| 1119 |
+
0.0, 0.0,
|
| 1120 |
+
float(stalls), float(srch),
|
| 1121 |
+
float(b_dist), float(f_dist),
|
| 1122 |
+
float(min(turn / 25.0, 1.0)),
|
| 1123 |
+
0.0,
|
| 1124 |
+
]
|
| 1125 |
+
|
| 1126 |
+
# ββ Build text context ββββββββββββββββββββββββββββββββββββ
|
| 1127 |
+
inv_ctx = session_state.get("inv_context", "")
|
| 1128 |
+
recent = history[-3:]
|
| 1129 |
+
text = " [SEP] ".join(
|
| 1130 |
+
f"{'S' if int(r['party'])==0 else 'B'}: {r['message']}"
|
| 1131 |
+
for r in recent
|
| 1132 |
+
)
|
| 1133 |
+
if inv_ctx:
|
| 1134 |
+
text = f"[INV: {inv_ctx[:120]}] " + text
|
| 1135 |
+
|
| 1136 |
+
enc = tokenizer(
|
| 1137 |
+
text, max_length=MAX_LEN,
|
| 1138 |
+
padding="max_length", truncation=True,
|
| 1139 |
+
return_tensors="pt"
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
dev = DEVICE
|
| 1143 |
+
ai_pty_t = torch.tensor([ai_party_int], dtype=torch.long).to(dev)
|
| 1144 |
+
cat_t = torch.tensor([CAT2IDX.get(category, 0)],
|
| 1145 |
+
dtype=torch.long).to(dev)
|
| 1146 |
+
ofn_t = torch.tensor([min(float(user_price)/lp, 3.0)],
|
| 1147 |
+
dtype=torch.float).to(dev)
|
| 1148 |
+
tn_t = torch.tensor([min(turn/25.0, 1.0)],
|
| 1149 |
+
dtype=torch.float).to(dev)
|
| 1150 |
+
bp_idx = BPERSONA2IDX.get(
|
| 1151 |
+
user_persona if is_user_buyer else ai_persona, 1)
|
| 1152 |
+
sp_idx = SPERSONA2IDX.get(
|
| 1153 |
+
ai_persona if is_user_buyer else user_persona, 0)
|
| 1154 |
+
bp_t = torch.tensor([bp_idx], dtype=torch.long).to(dev)
|
| 1155 |
+
sp_t = torch.tensor([sp_idx], dtype=torch.long).to(dev)
|
| 1156 |
+
mom_t = torch.tensor([mom], dtype=torch.float).to(dev)
|
| 1157 |
+
|
| 1158 |
+
with torch.no_grad():
|
| 1159 |
+
mt_logits, px = GLOBAL_MODEL(
|
| 1160 |
+
enc["input_ids"].to(dev),
|
| 1161 |
+
enc["attention_mask"].to(dev),
|
| 1162 |
+
ai_pty_t, cat_t, ofn_t, tn_t, bp_t, sp_t, mom_t
|
| 1163 |
+
)
|
| 1164 |
+
|
| 1165 |
+
mt_idx = mt_logits.argmax(dim=1).item()
|
| 1166 |
+
msg_type = IDX2MSG[mt_idx]
|
| 1167 |
+
ai_price = round(float(px.item()) * lp, 2)
|
| 1168 |
+
|
| 1169 |
+
# ββ Clamp AI price to valid range βββββββββββββββββββββββββ
|
| 1170 |
+
if ai_party_int == 0: # AI is seller
|
| 1171 |
+
ai_price = max(ai_price, s_res * 1.005)
|
| 1172 |
+
ai_price = min(ai_price, lp * 1.05)
|
| 1173 |
+
sp = ai_price
|
| 1174 |
+
else: # AI is buyer
|
| 1175 |
+
ai_price = min(ai_price, b_bud)
|
| 1176 |
+
ai_price = min(ai_price, sp * 0.99)
|
| 1177 |
+
ai_price = max(ai_price, lp * 0.25)
|
| 1178 |
+
bp = ai_price
|
| 1179 |
+
|
| 1180 |
+
# ββ Execute inventory search if triggered βββββββββββββββββ
|
| 1181 |
+
inv_context_text = ""
|
| 1182 |
+
if msg_type == "search":
|
| 1183 |
+
if ai_party_int == 0: # Seller searches for buyer
|
| 1184 |
+
results = search_inventory(
|
| 1185 |
+
category = category,
|
| 1186 |
+
max_price = b_bud if b_bud > 0 else lp * 1.1,
|
| 1187 |
+
min_price = b_est * 0.8 if b_est > 0 else 0,
|
| 1188 |
+
keywords = buyer_must_have,
|
| 1189 |
+
avoids = buyer_avoids,
|
| 1190 |
+
top_k = 3,
|
| 1191 |
+
)
|
| 1192 |
+
inv_context_text = format_inventory_context(
|
| 1193 |
+
results, reveal_floor=True
|
| 1194 |
+
)
|
| 1195 |
+
else: # Buyer searches seller inventory
|
| 1196 |
+
results = search_inventory(
|
| 1197 |
+
category = category,
|
| 1198 |
+
max_price = b_bud if b_bud > 0 else lp,
|
| 1199 |
+
keywords = buyer_must_have,
|
| 1200 |
+
avoids = buyer_avoids,
|
| 1201 |
+
top_k = 3,
|
| 1202 |
+
)
|
| 1203 |
+
inv_context_text = format_inventory_context(
|
| 1204 |
+
results, reveal_floor=False
|
| 1205 |
+
)
|
| 1206 |
+
session_state["inv_context"] = inv_context_text
|
| 1207 |
+
|
| 1208 |
+
# ββ Build AI message ββββββββββββββββββββββββββββββββββββββ
|
| 1209 |
+
ai_msg = _build_message(
|
| 1210 |
+
msg_type, ai_price, item,
|
| 1211 |
+
not is_user_buyer, ai_persona,
|
| 1212 |
+
inv_context_text
|
| 1213 |
+
)
|
| 1214 |
+
if msg_type == "search" and inv_context_text:
|
| 1215 |
+
ai_msg += (f"\n\nπ¦ **Inventory Results:**\n"
|
| 1216 |
+
f"```\n{inv_context_text}\n```")
|
| 1217 |
+
|
| 1218 |
+
history.append({
|
| 1219 |
+
"party": ai_party_int,
|
| 1220 |
+
"message": ai_msg,
|
| 1221 |
+
"offer_price": ai_price,
|
| 1222 |
+
"msg_type": msg_type,
|
| 1223 |
+
"turn_number": turn + 1,
|
| 1224 |
+
})
|
| 1225 |
+
ai_label = (f"π€ AI ({'Seller' if ai_party_int==0 else 'Buyer'}) "
|
| 1226 |
+
f"[{ai_persona}]")
|
| 1227 |
+
history_ui.append((None, f"{ai_label} [${ai_price:,.0f}]: {ai_msg}"))
|
| 1228 |
+
turn += 1
|
| 1229 |
|
| 1230 |
+
# ββ ZOPA ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1231 |
+
zopa = bp - s_res
|
| 1232 |
+
zopa_str = (f"β
ZOPA: +${zopa:,.0f} (deal zone exists)"
|
| 1233 |
+
if zopa > 0
|
| 1234 |
+
else f"β ZOPA: ${zopa:,.0f} (no overlap yet)")
|
| 1235 |
|
| 1236 |
+
# ββ Terminal check ββββββββββββββββββββββββββββββββββββββββ
|
| 1237 |
+
status = "active"
|
| 1238 |
+
if msg_type == "accept":
|
| 1239 |
+
status = "closed"
|
| 1240 |
+
history_ui.append(
|
| 1241 |
+
(None, f"β
**DEAL CLOSED at ${ai_price:,.0f}**")
|
| 1242 |
+
)
|
| 1243 |
+
elif msg_type == "exit":
|
| 1244 |
+
status = "ended"
|
| 1245 |
+
history_ui.append((None, "β Negotiation ended"))
|
| 1246 |
+
|
| 1247 |
+
probs = F.softmax(mt_logits, dim=1).squeeze().tolist()
|
| 1248 |
+
prob_str = " | ".join(
|
| 1249 |
+
f"{MSG_TYPES[i]}: {probs[i]:.2f}" for i in range(len(MSG_TYPES))
|
| 1250 |
+
)
|
| 1251 |
+
gap_pct = abs(sp - bp) / lp * 100
|
| 1252 |
+
summary = (f"Turn {turn} | Gap: {gap_pct:.1f}% | "
|
| 1253 |
+
f"Seller: ${sp:,.0f} | Buyer: ${bp:,.0f} | {zopa_str}")
|
| 1254 |
+
|
| 1255 |
+
session_state.update({
|
| 1256 |
+
"turn": turn,
|
| 1257 |
+
"sp": sp,
|
| 1258 |
+
"bp": bp,
|
| 1259 |
+
"history": history,
|
| 1260 |
+
"history_ui": history_ui,
|
| 1261 |
+
"status": status,
|
| 1262 |
+
})
|
| 1263 |
+
|
| 1264 |
+
return (session_state, history_ui, summary,
|
| 1265 |
+
msg_type, f"${ai_price:,.2f}", prob_str, inv_context_text)
|
| 1266 |
+
|
| 1267 |
+
|
| 1268 |
+
def reset_session():
|
| 1269 |
+
return {}, [], "Session reset.", "", "", "", ""
|
| 1270 |
+
|
| 1271 |
+
# ================================================================
|
| 1272 |
+
# STRATEGY GUIDES
|
| 1273 |
+
# ================================================================
|
| 1274 |
+
BUYER_GUIDE = """### π Buyer Playbook
|
| 1275 |
+
**Bounds to set before starting:**
|
| 1276 |
+
- **Budget** β your true ceiling. Encoded as soft penalty, not hard wall.
|
| 1277 |
+
- **Estimate** β fair value anchor. Sets your opening offer range.
|
| 1278 |
+
- **Must-have features** β filters inventory search. e.g. *bluetooth, low miles*
|
| 1279 |
+
- **Hard avoids** β instant deal-breakers. e.g. *salvage title, high mileage*
|
| 1280 |
+
|
| 1281 |
+
**Tactics the model trains on:**
|
| 1282 |
+
- π΄ Aggressive open at 55-65% of ask
|
| 1283 |
+
- πͺ Walk away at turn 3-4, return with prior offer
|
| 1284 |
+
- π Trigger search when gap > 12%: *"Do you have anything else?"*
|
| 1285 |
+
- β° Deadline pressure after patience threshold
|
| 1286 |
+
- πͺ Nibble for extras when gap < 8%
|
| 1287 |
+
- π€ Strategic persona: cite comps, build rapport"""
|
| 1288 |
+
|
| 1289 |
+
SELLER_GUIDE = """### π Seller Playbook
|
| 1290 |
+
**Bounds to set before starting:**
|
| 1291 |
+
- **Reservation price** β private floor. Model NEVER accepts below this.
|
| 1292 |
+
- **Urgency** β high urgency raises concession rate and search frequency.
|
| 1293 |
+
- **Inventory** β pre-loaded. Searched when buyer asks for alternatives.
|
| 1294 |
+
|
| 1295 |
+
**Tactics the model trains on:**
|
| 1296 |
+
- β Open 15-20% above target
|
| 1297 |
+
- π₯ Social proof: *"Two other buyers this weekend"*
|
| 1298 |
+
- π Proactively search inventory when buyer signals dissatisfaction
|
| 1299 |
+
- β° Urgency close: *"Need to close by Friday"*
|
| 1300 |
+
- π Return after walkaway with small concession
|
| 1301 |
+
- π Shrinking concessions signal approaching floor"""
|
| 1302 |
+
|
| 1303 |
+
# ================================================================
|
| 1304 |
+
# UI
|
| 1305 |
+
# ================================================================
|
| 1306 |
+
with gr.Blocks(title="ANP v5 | Bounded Negotiation",
|
| 1307 |
+
theme=gr.themes.Soft()) as demo:
|
| 1308 |
+
gr.Markdown(
|
| 1309 |
+
"# ANP v5 β Bounded Negotiation Engine\n"
|
| 1310 |
+
"Buyer bounds Β· Seller reservation Β· Inventory tool use Β· "
|
| 1311 |
+
"ZOPA tracking Β· Persona conditioning"
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
# ββ Training Tab ββββββββββββββββββββββββββββββββββββββββββ
|
| 1315 |
+
with gr.Tab("ποΈ Training"):
|
| 1316 |
with gr.Row():
|
| 1317 |
+
n_sessions = gr.Number(value=20000, label="Sessions")
|
| 1318 |
epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
|
| 1319 |
+
batch_size = gr.Slider(64, 1024, value=512, step=64,
|
| 1320 |
+
label="Batch Size")
|
| 1321 |
+
lr = gr.Number(value=3e-4, label="LR")
|
| 1322 |
+
tr_btn = gr.Button("π Train", variant="primary")
|
| 1323 |
+
status_box = gr.Textbox(label="Status", interactive=False,
|
| 1324 |
+
value="π΅ IDLE")
|
|
|
|
| 1325 |
with gr.Row():
|
| 1326 |
+
log_box = gr.Textbox(label="Logs", lines=14, interactive=False)
|
| 1327 |
plt_out = gr.Plot(label="Loss Curve")
|
| 1328 |
+
train_ready = gr.Textbox(visible=False)
|
| 1329 |
|
| 1330 |
+
# ββ Arena Tab βββββββββββββββββββββββββββββββββββββββββββββ
|
| 1331 |
+
with gr.Tab("π¬ Negotiation Arena"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1332 |
with gr.Row():
|
| 1333 |
+
|
| 1334 |
+
# Left panel β setup & analysis
|
| 1335 |
+
with gr.Column(scale=1):
|
| 1336 |
+
gr.Markdown("### βοΈ Session Setup")
|
| 1337 |
+
arena_cat = gr.Dropdown(
|
| 1338 |
+
CATEGORIES, value="used_car", label="Category"
|
| 1339 |
+
)
|
| 1340 |
+
arena_item = gr.Textbox(
|
| 1341 |
+
value="2019 Honda Civic", label="Item Name"
|
| 1342 |
+
)
|
| 1343 |
+
arena_lp = gr.Number(value=18500, label="List Price ($)")
|
| 1344 |
+
|
| 1345 |
+
with gr.Row():
|
| 1346 |
+
arena_user_pty = gr.Radio(
|
| 1347 |
+
["Buyer", "Seller"], value="Buyer", label="You are"
|
| 1348 |
+
)
|
| 1349 |
+
with gr.Row():
|
| 1350 |
+
arena_user_persona = gr.Dropdown(
|
| 1351 |
+
BUYER_PERSONAS, value="strategic",
|
| 1352 |
+
label="Your Persona"
|
| 1353 |
+
)
|
| 1354 |
+
arena_ai_persona = gr.Dropdown(
|
| 1355 |
+
SELLER_PERSONAS, value="firm",
|
| 1356 |
+
label="AI Persona"
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
gr.Markdown("---\n### π§ Buyer Bounds")
|
| 1360 |
+
buyer_budget = gr.Number(value=17000,
|
| 1361 |
+
label="Max Budget ($)")
|
| 1362 |
+
buyer_estimate = gr.Number(value=15500,
|
| 1363 |
+
label="Fair Value Estimate ($)")
|
| 1364 |
+
buyer_avoids = gr.Textbox(
|
| 1365 |
+
value="salvage,flood",
|
| 1366 |
+
label="Hard Avoids (comma list)"
|
| 1367 |
+
)
|
| 1368 |
+
buyer_must_have = gr.Textbox(
|
| 1369 |
+
value="bluetooth",
|
| 1370 |
+
label="Must-Have Features (comma list)"
|
| 1371 |
+
)
|
| 1372 |
+
|
| 1373 |
+
gr.Markdown("---\n### π€ Seller Bounds")
|
| 1374 |
+
seller_reservation = gr.Number(
|
| 1375 |
+
value=15000, label="Seller Floor / Reservation ($)"
|
| 1376 |
+
)
|
| 1377 |
+
seller_urgency = gr.Dropdown(
|
| 1378 |
+
["low", "medium", "high"], value="medium",
|
| 1379 |
+
label="Seller Urgency"
|
| 1380 |
+
)
|
| 1381 |
+
|
| 1382 |
+
reset_btn = gr.Button("π New Session", variant="secondary")
|
| 1383 |
+
|
| 1384 |
+
gr.Markdown("---\n### π Turn Analysis")
|
| 1385 |
+
arena_summary = gr.Textbox(
|
| 1386 |
+
label="Gap / ZOPA", interactive=False
|
| 1387 |
+
)
|
| 1388 |
+
arena_action = gr.Textbox(
|
| 1389 |
+
label="AI Action", interactive=False
|
| 1390 |
+
)
|
| 1391 |
+
arena_price = gr.Textbox(
|
| 1392 |
+
label="AI Price", interactive=False
|
| 1393 |
+
)
|
| 1394 |
+
arena_probs = gr.Textbox(
|
| 1395 |
+
label="Action Probabilities", interactive=False
|
| 1396 |
+
)
|
| 1397 |
+
inv_display = gr.Textbox(
|
| 1398 |
+
label="π Last Inventory Search",
|
| 1399 |
+
lines=5, interactive=False
|
| 1400 |
+
)
|
| 1401 |
+
|
| 1402 |
+
# Right panel β chat
|
| 1403 |
+
with gr.Column(scale=2):
|
| 1404 |
+
gr.Markdown("### π£οΈ Negotiation")
|
| 1405 |
+
chatbot = gr.Chatbot(height=520, label="Conversation")
|
| 1406 |
+
with gr.Row():
|
| 1407 |
+
arena_offer = gr.Number(value=16000,
|
| 1408 |
+
label="Your Offer ($)")
|
| 1409 |
+
arena_msg = gr.Textbox(
|
| 1410 |
+
placeholder="Type your message...",
|
| 1411 |
+
label="Your Message", scale=3
|
| 1412 |
+
)
|
| 1413 |
+
send_btn = gr.Button("Send β", variant="primary")
|
| 1414 |
+
|
| 1415 |
+
# ββ Strategy Guides Tab βββββββββββββββββββββββββββββββββββ
|
| 1416 |
+
with gr.Tab("π Playbooks"):
|
| 1417 |
with gr.Row():
|
| 1418 |
+
gr.Markdown(BUYER_GUIDE)
|
| 1419 |
+
gr.Markdown(SELLER_GUIDE)
|
| 1420 |
+
|
| 1421 |
+
# ββ Inventory Browser Tab βββββββββββββββββββββββββββββββββ
|
| 1422 |
+
with gr.Tab("π¦ Inventory"):
|
| 1423 |
+
gr.Markdown(
|
| 1424 |
+
"### Current Inventory Database\n"
|
| 1425 |
+
"Plain text rows β term-frequency search, no vectors at rest."
|
| 1426 |
+
)
|
| 1427 |
+
inv_text = "\n".join(
|
| 1428 |
+
f"[{it['id']}] {it['category']} | {it['name']} | "
|
| 1429 |
+
f"{it['condition']} | Ask: ${it['ask_price']:,} | "
|
| 1430 |
+
f"Features: {it['features']}"
|
| 1431 |
+
for it in INVENTORY
|
| 1432 |
+
)
|
| 1433 |
+
gr.Textbox(
|
| 1434 |
+
value=inv_text, lines=30, interactive=False,
|
| 1435 |
+
label="Inventory (floor hidden from buyer-facing searches)"
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
# ββ State βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1439 |
+
session_state = gr.State({})
|
| 1440 |
+
|
| 1441 |
+
def update_personas(party):
|
| 1442 |
+
if party == "Buyer":
|
| 1443 |
+
return (
|
| 1444 |
+
gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
|
| 1445 |
+
gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
|
| 1446 |
+
)
|
| 1447 |
+
return (
|
| 1448 |
+
gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
|
| 1449 |
+
gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
|
| 1450 |
+
)
|
| 1451 |
+
|
| 1452 |
+
arena_user_pty.change(
|
| 1453 |
+
update_personas,
|
| 1454 |
+
inputs=[arena_user_pty],
|
| 1455 |
+
outputs=[arena_user_persona, arena_ai_persona]
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
tr_btn.click(
|
| 1459 |
+
run_training,
|
| 1460 |
+
inputs=[n_sessions, epochs, batch_size, lr],
|
| 1461 |
+
outputs=[status_box, log_box, plt_out, train_ready]
|
| 1462 |
+
)
|
| 1463 |
+
|
| 1464 |
+
send_btn.click(
|
| 1465 |
+
run_inference_turn,
|
| 1466 |
+
inputs=[
|
| 1467 |
+
session_state,
|
| 1468 |
+
arena_cat, arena_item, arena_lp,
|
| 1469 |
+
arena_offer, arena_msg,
|
| 1470 |
+
arena_user_pty, arena_user_persona, arena_ai_persona,
|
| 1471 |
+
buyer_budget, buyer_estimate,
|
| 1472 |
+
buyer_avoids, buyer_must_have,
|
| 1473 |
+
seller_reservation, seller_urgency,
|
| 1474 |
+
],
|
| 1475 |
+
outputs=[
|
| 1476 |
+
session_state, chatbot, arena_summary,
|
| 1477 |
+
arena_action, arena_price, arena_probs, inv_display,
|
| 1478 |
+
]
|
| 1479 |
+
)
|
| 1480 |
+
|
| 1481 |
+
reset_btn.click(
|
| 1482 |
+
reset_session,
|
| 1483 |
+
outputs=[
|
| 1484 |
+
session_state, chatbot, arena_summary,
|
| 1485 |
+
arena_action, arena_price, arena_probs, inv_display,
|
| 1486 |
+
]
|
| 1487 |
+
)
|
| 1488 |
+
|
| 1489 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|