everydaytok commited on
Commit
f070b57
·
verified ·
1 Parent(s): 8376a9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1399 -328
app.py CHANGED
@@ -1,13 +1,15 @@
1
  # ================================================================
2
- # ANP Model | HF Free Tier (16GB CPU) | Background Training Daemon
 
 
3
  # ================================================================
4
- import os, time, math, random, uuid, threading
5
- from typing import List, Dict
6
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
- from torch.utils.data import Dataset, DataLoader
11
  from torch.optim import AdamW
12
  from torch.optim.lr_scheduler import CosineAnnealingLR
13
  from transformers import BertTokenizerFast
@@ -20,399 +22,1468 @@ import matplotlib.pyplot as plt
20
  random.seed(42)
21
  torch.manual_seed(42)
22
 
23
- # ── Config & Globals ──────────────────────────────────────────
24
- DEVICE = torch.device("cpu") # HF Free tier is CPU
25
- MSG_TYPES = ["offer", "counter", "accept", "reject", "exit", "stall"]
 
26
  MSG2IDX = {m: i for i, m in enumerate(MSG_TYPES)}
27
  IDX2MSG = {i: m for m, i in MSG2IDX.items()}
28
- CATEGORIES = ["used_car","domain_name","freelance_design","saas_license","electronics","bulk_groceries","consulting"]
 
 
29
  CAT2IDX = {c: i for i, c in enumerate(CATEGORIES)}
30
 
31
- MAX_LEN = 256
 
 
 
 
 
32
  D_MODEL = 384
33
  N_HEADS = 6
34
  N_LAYERS = 6
35
  FFN_DIM = 1024
36
 
37
- print("Loading tokenizer...")
 
 
 
 
38
  tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
 
39
 
40
- # ── Thread-Safe State Manager ─────────────────────────────────
41
- class TrainingState:
42
- def __init__(self):
43
- self.lock = threading.Lock()
44
- self.is_running = False
45
- self.current_epoch = 0
46
- self.total_epochs = 0
47
- self.batch_progress = ""
48
- self.logs = []
49
- self.losses = []
50
- self.model_ready = False
51
-
52
- def log(self, msg: str):
53
- with self.lock:
54
- ts = time.strftime("%H:%M:%S")
55
- self.logs.append(f"[{ts}] {msg}")
56
- if len(self.logs) > 50: # Keep dashboard clean
57
- self.logs.pop(0)
58
- print(msg)
59
-
60
- STATE = TrainingState()
61
- GLOBAL_MODEL = None # Holds the model in memory for inference
62
-
63
- # ── Synthetic Data Generator ──────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def generate_sessions(n_sessions: int) -> List[Dict]:
65
- """Generates synthetic negotiation data quickly in memory."""
66
  all_rows = []
67
- cats = list(CATEGORIES)
68
-
69
- # Simple templates for generator (training text)
70
- _SO = ["{item} for sale. Asking ${p:,.0f}.", "Listing {item} at ${p:,.0f}."]
71
- _SC = ["Best I can do is ${p:,.0f}.", "Can come down to ${p:,.0f}."]
72
- _SS = ["Let me think about it.", "Need to check with my partner."]
73
- _SA = ["Deal at ${p:,.0f}.", "Agreed. ${p:,.0f}."]
74
- _BC = ["Offering ${p:,.0f}.", "${p:,.0f} is my ceiling."]
75
- _BE = ["Too far apart. Going to pass.", "Price doesn't work for me."]
76
-
77
- def _t(templates, item="", p=0):
78
- return random.choice(templates).format(item=item, p=p)
79
-
80
- for _ in range(n_sessions):
81
- cat = random.choice(cats)
82
- item = f"Generic {cat} Item"
83
- lp = round(random.uniform(500, 10000), -1)
84
- sid = f"SYN-{uuid.uuid4().hex[:6].upper()}"
85
- turn = 0
86
- session_rows = []
87
 
88
  def add(party, price, mtype, msg):
89
  nonlocal turn
90
  turn += 1
91
- session_rows.append({
92
- "session_id": sid, "turn_number": turn, "party": party,
93
- "category": cat, "item": item, "list_price": lp,
94
- "offer_price": price, "msg_type": mtype, "message": msg
 
 
 
 
 
 
 
 
 
 
 
95
  })
96
 
97
  sp = lp
98
- bp = round(lp * random.uniform(0.6, 0.8), -1)
99
-
100
- add(0, sp, "offer", _t(_SO, item=item, p=sp))
101
- add(1, bp, "counter", _t(_BC, p=bp))
102
-
103
- target = random.choice(["accepted", "abandoned", "rejected"])
104
- for _ in range(random.randint(2, 6)):
105
- gap = sp - bp
106
- if target == "accepted" and (gap / lp) < 0.05:
107
- final_p = round((sp + bp) / 2, -1)
108
- add(0 if random.random() < 0.5 else 1, final_p, "accept", _t(_SA, p=final_p))
109
- break
110
- if target == "abandoned" and random.random() < 0.2:
111
- add(0, sp, "stall", _t(_SS))
112
- add(1, bp, "exit", _t(_BE))
113
- break
114
-
115
- sp = max(bp + gap * 0.3, sp - lp * random.uniform(0.02, 0.05))
116
- sp = round(sp, -1)
117
- add(0, sp, "counter", _t(_SC, p=sp))
118
-
119
- gap = sp - bp
120
- if target == "accepted" and (gap / lp) < 0.05:
121
- final_p = round((sp + bp) / 2, -1)
122
- add(1, final_p, "accept", _t(_SA, p=final_p))
123
- break
124
-
125
- bp = min(sp - gap * 0.3, bp + lp * random.uniform(0.02, 0.05))
126
- bp = round(bp, -1)
127
- add(1, bp, "counter", _t(_BC, p=bp))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  else:
129
- add(1, bp, "exit", _t(_BE))
130
-
131
- all_rows.extend(session_rows)
 
 
 
 
 
 
132
  return all_rows
133
 
134
- # ── Dataset & Model ───────────────────────────────────────────
135
- class NegotiationDataset(Dataset):
136
- def __init__(self, rows: List[Dict]):
137
- self.samples = []
138
- sessions = {}
139
- for r in rows:
140
- sessions.setdefault(r["session_id"], []).append(r)
141
-
142
- for turns in sessions.values():
143
- turns = sorted(turns, key=lambda x: int(x["turn_number"]))
144
- lp = float(turns[0]["list_price"])
145
- if lp <= 0: continue
146
-
147
- for i in range(1, len(turns)):
148
- hist = turns[:i]
149
- tgt = turns[i]
150
- text = " [SEP] ".join(f"{'Seller' if t['party']==0 else 'Buyer'}: {t['message']}" for t in hist)
151
- self.samples.append({
152
- "text": text,
153
- "party": int(tgt["party"]),
154
- "category": CAT2IDX.get(tgt["category"], 0),
155
- "ofn": min(float(tgt["offer_price"]) / lp, 3.0),
156
- "tn": min(int(tgt["turn_number"]) / 20.0, 1.0),
157
- "msg_type": MSG2IDX.get(tgt["msg_type"], 1),
158
- "price_t": min(float(tgt["offer_price"]) / lp, 3.0),
159
- })
160
-
161
- def __len__(self): return len(self.samples)
162
- def __getitem__(self, idx):
163
- s = self.samples[idx]
164
- enc = tokenizer(s["text"], max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt")
165
- return {
166
- "input_ids": enc["input_ids"].squeeze(0),
167
- "attn_mask": enc["attention_mask"].squeeze(0),
168
- "party": torch.tensor(s["party"], dtype=torch.long),
169
- "category": torch.tensor(s["category"], dtype=torch.long),
170
- "ofn": torch.tensor(s["ofn"], dtype=torch.float),
171
- "tn": torch.tensor(s["tn"], dtype=torch.float),
172
- "msg_type": torch.tensor(s["msg_type"], dtype=torch.long),
173
- "price_t": torch.tensor(s["price_t"], dtype=torch.float),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  }
175
 
 
 
 
 
 
176
  class PositionalEncoding(nn.Module):
177
  def __init__(self, d: int, max_len: int = 512):
178
  super().__init__()
179
  self.drop = nn.Dropout(0.1)
180
- pe = torch.zeros(max_len, d)
181
  pos = torch.arange(max_len).unsqueeze(1).float()
182
- div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
 
 
183
  pe[:, 0::2] = torch.sin(pos * div)
184
  pe[:, 1::2] = torch.cos(pos * div)
185
  self.register_buffer("pe", pe.unsqueeze(0))
186
 
187
- def forward(self, x): return self.drop(x + self.pe[:, :x.size(1)])
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  class NegotiationTransformer(nn.Module):
190
  def __init__(self):
191
  super().__init__()
192
- self.emb = nn.Embedding(30522, D_MODEL, padding_idx=0)
193
- self.pos = PositionalEncoding(D_MODEL)
194
- enc_layer = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, FFN_DIM, dropout=0.1, batch_first=True, norm_first=True)
195
- self.encoder = nn.TransformerEncoder(enc_layer, N_LAYERS)
196
- self.p_emb = nn.Embedding(2, 32)
197
- self.c_emb = nn.Embedding(len(CATEGORIES), 64)
198
- self.fusion = nn.Sequential(nn.Linear(D_MODEL + 32 + 64 + 2, D_MODEL), nn.GELU())
 
 
 
 
 
 
 
 
 
199
  self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
200
- self.px_head = nn.Sequential(nn.Linear(D_MODEL, 128), nn.GELU(), nn.Linear(128, 1), nn.Softplus())
 
 
 
201
 
202
- def forward(self, ids, mask, party, cat, ofn, tn):
203
- x = self.pos(self.emb(ids))
204
- x = self.encoder(x, src_key_padding_mask=(mask == 0))
205
  cls = x[:, 0]
206
- f = self.fusion(torch.cat([cls, self.p_emb(party), self.c_emb(cat), torch.stack([ofn, tn], dim=1)], dim=1))
 
 
 
 
 
 
 
 
 
207
  return self.msg_head(f), self.px_head(f).squeeze(1)
208
 
209
- # ── Background Training Daemon ────────────────────────────────
210
- def _training_thread_target(n_sessions: int, epochs: int, batch_size: int, lr: float):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  global GLOBAL_MODEL
 
 
 
 
 
 
 
 
 
 
 
212
  try:
213
- STATE.log(f"Starting data generation: {n_sessions:,} sessions (~{n_sessions*5:,} rows)")
214
-
215
- # Generation runs in main memory, yields CPU often enough
216
- rows = generate_sessions(n_sessions)
217
- STATE.log(f"Data generated. Tokenizing into dataset...")
218
-
219
- dataset = NegotiationDataset(rows)
220
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
221
-
222
- STATE.log(f"Dataset ready: {len(dataset):,} samples. Initializing Model...")
223
- model = NegotiationTransformer().to(DEVICE)
224
-
225
- opt = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
226
- sch = CosineAnnealingLR(opt, T_max=epochs)
227
- ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
228
-
229
- with STATE.lock:
230
- STATE.total_epochs = epochs
231
- STATE.losses = []
232
-
233
- STATE.log("Entering Training Loop (CPU mode).")
234
  total_batches = len(loader)
235
-
236
- for ep in range(epochs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  model.train()
238
  ep_loss = 0.0
239
- with STATE.lock:
240
- STATE.current_epoch = ep + 1
241
-
242
  for i, batch in enumerate(loader):
243
- if i % max(1, total_batches // 10) == 0:
244
- with STATE.lock:
245
- STATE.batch_progress = f"Epoch {ep+1}/{epochs} | Batch {i}/{total_batches}"
246
-
247
- opt.zero_grad()
248
- mt_logits, px_pred = model(
249
- batch["input_ids"].to(DEVICE), batch["attn_mask"].to(DEVICE),
250
- batch["party"].to(DEVICE), batch["category"].to(DEVICE),
251
- batch["ofn"].to(DEVICE), batch["tn"].to(DEVICE)
252
- )
253
- loss = ce(mt_logits, batch["msg_type"].to(DEVICE)) + 0.5 * mse(px_pred, batch["price_t"].to(DEVICE))
254
- loss.backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  nn.utils.clip_grad_norm_(model.parameters(), 1.0)
256
- opt.step()
 
257
  ep_loss += loss.item()
258
-
259
  sch.step()
260
- avg_loss = ep_loss / max(total_batches, 1)
261
-
262
- with STATE.lock:
263
- STATE.losses.append(avg_loss)
264
- STATE.log(f"Epoch {ep+1} complete. Loss: {avg_loss:.4f}")
 
 
 
 
 
265
 
266
- STATE.log("Training complete. Applying weights to Global Model.")
267
  model.eval()
268
  GLOBAL_MODEL = model
269
- with STATE.lock:
270
- STATE.model_ready = True
271
-
272
  except Exception as e:
273
- STATE.log(f"ERROR: {str(e)}")
274
- finally:
275
- with STATE.lock:
276
- STATE.is_running = False
277
-
278
- def start_training(n_sessions, epochs, batch_size, lr):
279
- with STATE.lock:
280
- if STATE.is_running:
281
- return "Training is already running!"
282
- STATE.is_running = True
283
- STATE.logs = []
284
- STATE.batch_progress = "Initializing..."
285
-
286
- t = threading.Thread(target=_training_thread_target, args=(int(n_sessions), int(epochs), int(batch_size), float(lr)), daemon=True)
287
- t.start()
288
- return "Background training thread triggered."
289
-
290
- # ── Inference with Pre-built Templates ────────────────────────
291
- def _get_template_message(msg_type: str, price: float, item: str, is_buyer: bool) -> str:
292
- """The 'Mouth': Translates the Model's strategy (msg_type, price) into prose."""
293
- px = f"${price:,.2f}"
294
  if is_buyer:
295
- templates = {
296
- "offer": f"I'll start the bidding at {px} for the {item}.",
297
- "counter": random.choice([f"I can offer {px}.", f"How about {px}?", f"My counter is {px}."]),
298
- "accept": f"{px} works for me. I'll take it.",
299
- "reject": "That's too high for my budget, I have to pass.",
300
- "stall": "I need to check my budget and get back to you.",
301
- "exit": "We're too far apart. Moving on."
302
  }
 
 
 
303
  else:
304
- templates = {
305
- "offer": f"I'm looking to get {px} for the {item}.",
306
- "counter": random.choice([f"I can drop to {px}.", f"Best I can do right now is {px}.", f"Let's meet at {px}."]),
307
- "accept": f"You got a deal at {px}.",
308
- "reject": "I can't go that low.",
309
- "stall": "Let me see if I have other offers first.",
310
- "exit": "I can't sell it for that. Goodbye."
311
  }
312
- return templates.get(msg_type, f"Action: {msg_type} at {px}")
 
 
 
 
 
 
 
 
 
313
 
314
- def predict(category, item, list_price, current_offer, history_text, party_str):
 
 
 
 
 
 
 
 
315
  if GLOBAL_MODEL is None:
316
- return "Model not trained yet. Run training tab first.", "", "", ""
317
-
318
- try:
319
- lp, cp = float(list_price), float(current_offer)
320
- is_buyer = (party_str == "Buyer")
321
- pty = 1 if is_buyer else 0
322
-
323
- enc = tokenizer(history_text or "(start)", max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt")
324
- turns = len([l for l in history_text.strip().split("\n") if l.strip()])
325
-
326
- p = torch.tensor([pty], dtype=torch.long)
327
- c = torch.tensor([CAT2IDX.get(category, 0)], dtype=torch.long)
328
- ofn = torch.tensor([min(cp/lp, 3.0)], dtype=torch.float)
329
- tn = torch.tensor([min(turns/20.0, 1.0)], dtype=torch.float)
330
-
331
- with torch.no_grad():
332
- mt_logits, px = GLOBAL_MODEL(enc["input_ids"], enc["attention_mask"], p, c, ofn, tn)
333
-
334
- mt_idx = mt_logits.argmax(dim=1).item()
335
- msg_out = IDX2MSG[mt_idx]
336
- price_out = round(px.item() * lp, 2)
337
-
338
- prose_msg = _get_template_message(msg_out, price_out, item, is_buyer)
339
- probs = F.softmax(mt_logits, dim=1).squeeze().tolist()
340
- prob_str = " | ".join(f"{MSG_TYPES[i]}: {probs[i]:.2f}" for i in range(len(MSG_TYPES)))
341
-
342
- return msg_out, f"${price_out:,.2f}", prose_msg, prob_str
343
- except Exception as e:
344
- return "Error", "", str(e), ""
345
-
346
- # ── Dashboard UI (Polling) ────────────────────────────────────
347
- def refresh_dashboard():
348
- with STATE.lock:
349
- is_run = STATE.is_running
350
- status = "🟒 ACTIVE - " + STATE.batch_progress if is_run else "πŸ”΄ IDLE"
351
- log_text = "\n".join(STATE.logs)
352
- losses = list(STATE.losses)
353
- ready = "βœ… Ready" if STATE.model_ready else "❌ Needs Training"
354
-
355
- fig, ax = plt.subplots(figsize=(6, 3))
356
- if losses:
357
- ax.plot(range(1, len(losses)+1), losses, "b-o", markersize=4)
358
- ax.set_title("Training Loss")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  else:
360
- ax.text(0.5, 0.5, 'No data yet', ha='center', va='center', alpha=0.5)
361
- ax.grid(alpha=0.3)
362
- plt.tight_layout()
363
-
364
- return status, log_text, fig, ready
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
- # ── Gradio ────────────────────────────────────────────────────
367
- with gr.Blocks(title="ANP | HF Daemon Trainer", theme=gr.themes.Soft()) as demo:
368
- gr.Markdown("# ANP Background Trainer\nTrains on the HF free CPU via a background thread while you watch.")
 
 
369
 
370
- with gr.Tab("Dashboard & Training"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  with gr.Row():
372
- n_sessions = gr.Number(value=40000, label="Sessions (~200k rows)")
373
  epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
374
- batch_size = gr.Slider(16, 256, value=64, step=16, label="Batch Size")
375
- lr = gr.Number(value=5e-4, label="Learning Rate")
376
-
377
- tr_btn = gr.Button("πŸš€ Start Background Training", variant="primary")
378
-
379
- gr.Markdown("### Real-Time Status *(Polls automatically)*")
380
- status_box = gr.Textbox(label="Thread Status", interactive=False)
381
  with gr.Row():
382
- log_box = gr.Textbox(label="System Logs", lines=12, interactive=False)
383
  plt_out = gr.Plot(label="Loss Curve")
 
384
 
385
- # Gradio Timer continuously updates the dashboard every 3 seconds
386
- gr.Timer(3, active=True).tick(
387
- fn=refresh_dashboard,
388
- outputs=[status_box, log_box, plt_out, gr.Textbox(visible=False)]
389
- )
390
-
391
- with gr.Tab("Inference Sandbox"):
392
- ready_indicator = gr.Textbox(label="Model Status", interactive=False)
393
- gr.Timer(5, active=True).tick(fn=lambda: "βœ… Ready" if STATE.model_ready else "❌ Needs Training", outputs=[ready_indicator])
394
-
395
- with gr.Row():
396
- d_cat = gr.Dropdown(CATEGORIES, value="used_car", label="Category")
397
- d_pty = gr.Radio(["Seller","Buyer"], value="Buyer", label="Party to Simulate")
398
  with gr.Row():
399
- d_lp = gr.Number(value=18500, label="List Price ($)")
400
- d_co = gr.Number(value=16000, label="Current Offer ($)")
401
-
402
- d_item = gr.Textbox(value="2019 Honda Civic", label="Item Name (for template)")
403
- d_hist = gr.Textbox(lines=4, label="Turn History", placeholder="Seller: Asking $18,500.\nBuyer: I can do $15,000.")
404
-
405
- d_btn = gr.Button("Generate Move & Message", variant="primary")
406
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  with gr.Row():
408
- d_msg = gr.Textbox(label="Action Head")
409
- d_px = gr.Textbox(label="Pricing Head")
410
- d_prose = gr.Textbox(label="Generated Message (Template)", lines=2)
411
- d_prob = gr.Textbox(label="Action Probabilities")
412
-
413
- d_btn.click(predict, inputs=[d_cat, d_item, d_lp, d_co, d_hist, d_pty], outputs=[d_msg, d_px, d_prose, d_prob])
414
-
415
- tr_btn.click(start_training, inputs=[n_sessions, epochs, batch_size, lr], outputs=[status_box])
416
-
417
- # Launch blocking the main thread, daemons will run in background
418
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ================================================================
2
+ # ANP v5 | Bounded Multi-Agent Negotiation + Inventory Tool Use
3
+ # Buyer bounds · Seller inventory context · Search action head
4
+ # ZOPA tracking · Reservation prices · Ranked inventory matching
5
  # ================================================================
6
+ import os, time, math, random, uuid, gc
7
+ from typing import List, Dict, Tuple, Optional
8
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader, TensorDataset
13
  from torch.optim import AdamW
14
  from torch.optim.lr_scheduler import CosineAnnealingLR
15
  from transformers import BertTokenizerFast
 
22
  random.seed(42)
23
  torch.manual_seed(42)
24
 
25
+ # ── Config ────────────────────────────────────────────────────
26
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+
28
+ MSG_TYPES = ["offer","counter","accept","reject","exit","stall","search"]
29
  MSG2IDX = {m: i for i, m in enumerate(MSG_TYPES)}
30
  IDX2MSG = {i: m for m, i in MSG2IDX.items()}
31
+
32
+ CATEGORIES = ["used_car","domain_name","freelance_design","saas_license",
33
+ "electronics","bulk_groceries","consulting"]
34
  CAT2IDX = {c: i for i, c in enumerate(CATEGORIES)}
35
 
36
+ BUYER_PERSONAS = ["aggressive","patient","skeptical","impulsive","strategic"]
37
+ SELLER_PERSONAS = ["firm","motivated","anchoring","collaborative","desperate"]
38
+ BPERSONA2IDX = {p: i for i, p in enumerate(BUYER_PERSONAS)}
39
+ SPERSONA2IDX = {p: i for i, p in enumerate(SELLER_PERSONAS)}
40
+
41
+ MAX_LEN = 96
42
  D_MODEL = 384
43
  N_HEADS = 6
44
  N_LAYERS = 6
45
  FFN_DIM = 1024
46
 
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+ torch.backends.cudnn.allow_tf32 = True
49
+ torch.backends.cudnn.benchmark = True
50
+
51
+ print(f"Device: {DEVICE}")
52
  tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
53
+ GLOBAL_MODEL = None
54
 
55
+ # ================================================================
56
+ # INVENTORY DATABASE
57
+ # ================================================================
58
+ def _make_inventory() -> List[Dict]:
59
+ inv = []
60
+ templates = {
61
+ "used_car": [
62
+ ("2018 Toyota Camry", "Good", 14500, 12800,
63
+ "sunroof,bluetooth,low miles"),
64
+ ("2019 Honda Civic", "Excellent", 18500, 16200,
65
+ "one owner,new tires,clean title"),
66
+ ("2020 Ford F-150", "Good", 28000, 24500,
67
+ "tow package,crew cab,4WD"),
68
+ ("2016 BMW 3 Series", "Fair", 16000, 13500,
69
+ "sport package,leather,sunroof"),
70
+ ("2021 Tesla Model 3", "Excellent", 38000, 35000,
71
+ "autopilot,long range,premium audio"),
72
+ ("2017 Chevy Silverado", "Good", 22000, 19000,
73
+ "4WD,tow hitch,extended cab"),
74
+ ("2015 Honda Accord", "Fair", 11000, 9200,
75
+ "2 owners,new brakes,cloth seats"),
76
+ ("2022 Toyota RAV4", "Excellent", 32000, 29500,
77
+ "hybrid,AWD,apple carplay"),
78
+ ],
79
+ "electronics": [
80
+ ("MacBook Pro 14 M2", "Excellent", 1800, 1600,
81
+ "16GB RAM,512GB SSD,AppleCare"),
82
+ ("iPhone 14 Pro", "Good", 900, 780,
83
+ "256GB,space black,minor scratches"),
84
+ ("Sony 65in 4K TV", "Excellent", 750, 620,
85
+ "OLED,smart tv,2 years old"),
86
+ ("iPad Air Gen5", "Good", 550, 470,
87
+ "wifi+cellular,pencil included"),
88
+ ("Gaming PC RTX4070", "Excellent", 1400, 1200,
89
+ "32GB RAM,1TB NVMe,water cooled"),
90
+ ("DJI Mavic 3", "Good", 900, 780,
91
+ "4K camera,3 batteries,case"),
92
+ ],
93
+ "domain_name": [
94
+ ("QuickLoan.io", "Premium", 12000, 9500,
95
+ "fintech,4 years aged,high DA"),
96
+ ("FreshMeals.com", "Good", 4500, 3800,
97
+ "food delivery niche,aged 6yr"),
98
+ ("TechPulse.net", "Good", 2200, 1800,
99
+ "tech blog ready,clean history"),
100
+ ("GreenHome.co", "Excellent", 5500, 4600,
101
+ "eco niche,brandable,short"),
102
+ ("RapidShip.io", "Premium", 8000, 6800,
103
+ "logistics niche,exact match"),
104
+ ],
105
+ "freelance_design": [
106
+ ("Logo + Brand Kit", "Standard", 800, 650,
107
+ "5 concepts,unlimited revisions,source files"),
108
+ ("Website Redesign", "Premium", 3500, 2800,
109
+ "5 pages,mobile,figma handoff"),
110
+ ("UI/UX App Design", "Premium", 5000, 4200,
111
+ "full wireframes,prototype,design system"),
112
+ ("Social Media Pack", "Standard", 600, 480,
113
+ "30 templates,brand colors,canva ready"),
114
+ ("Pitch Deck Design", "Standard", 1200, 950,
115
+ "20 slides,animations,2 revisions"),
116
+ ],
117
+ "saas_license": [
118
+ ("CRM Pro Annual", "Standard", 2400, 1900,
119
+ "unlimited users,API access,support"),
120
+ ("Analytics Suite", "Premium", 4800, 3900,
121
+ "real-time,custom dashboards,export"),
122
+ ("Project Mgmt Tool", "Standard", 1200, 980,
123
+ "50 users,gantt,integrations"),
124
+ ("Email Marketing Pro", "Standard", 960, 780,
125
+ "100k contacts,automation,A/B"),
126
+ ],
127
+ "bulk_groceries": [
128
+ ("Organic Coffee 50lb", "Fresh", 420, 350,
129
+ "single origin,roasted weekly,wholesale"),
130
+ ("Olive Oil 5 Gal", "Premium", 280, 230,
131
+ "extra virgin,cold press,Italian"),
132
+ ("Almond Flour 25lb", "Fresh", 180, 145,
133
+ "blanched,gluten free,bulk"),
134
+ ("Protein Powder 20lb", "Good", 260, 210,
135
+ "whey isolate,unflavored,NSF cert"),
136
+ ],
137
+ "consulting": [
138
+ ("SEO Audit + 90 Day Plan", "Standard", 1500, 1200,
139
+ "technical+content,keyword research,monthly report"),
140
+ ("Financial Model Build", "Premium", 3500, 2900,
141
+ "3 statement,DCF,scenario analysis"),
142
+ ("HR Policy Package", "Standard", 1800, 1450,
143
+ "employee handbook,policies,compliance"),
144
+ ("Marketing Strategy Q", "Premium", 4200, 3500,
145
+ "market research,ICP,channel plan"),
146
+ ],
147
+ }
148
+ for cat, items in templates.items():
149
+ for (name, cond, ask, res, feats) in items:
150
+ inv.append({
151
+ "id": str(uuid.uuid4().hex[:8]),
152
+ "category": cat,
153
+ "name": name,
154
+ "condition": cond,
155
+ "ask_price": ask,
156
+ "reservation_price": res,
157
+ "features": feats,
158
+ "notes": "",
159
+ })
160
+ return inv
161
+
162
+ INVENTORY: List[Dict] = _make_inventory()
163
+
164
def _split_terms(raw: Optional[str]) -> List[str]:
    """Split a comma-separated string into lowercase, non-empty terms."""
    # ``raw or ""`` tolerates None as well as the empty string.
    return [part.strip().lower() for part in (raw or "").split(",") if part.strip()]


def search_inventory(
    category: str,
    max_price: float,
    min_price: float = 0,
    keywords: str = "",
    top_k: int = 4,
    avoids: str = "",
    inventory: Optional[List[Dict]] = None,
) -> List[Dict]:
    """Return the best-matching inventory items for a buyer's constraints.

    Args:
        category: Only items whose ``category`` equals this value are kept.
        max_price: Buyer's ceiling. Items asking more than 15% above it are
            dropped.
        min_price: Items asking less than this are dropped. When it is
            positive, the price target is the midpoint of the two bounds;
            otherwise the target is 80% of ``max_price``.
        keywords: Comma-separated terms; each term found in the item's
            name/features/notes (case-insensitive substring) adds 2 to the
            score. ``None`` is treated as no keywords.
        top_k: Maximum number of results. Coerced to ``int`` and clamped at 0
            so float or negative inputs behave predictably.
        avoids: Comma-separated terms; an item containing any of them is
            excluded. ``None`` is treated as no exclusions.
        inventory: Items to search. Defaults to the module-level
            ``INVENTORY`` when omitted.

    Returns:
        Up to ``top_k`` shallow copies of matching items, best first, each
        with an extra ``_score`` key. The searched items are not modified.
    """
    pool = INVENTORY if inventory is None else inventory
    wanted = _split_terms(keywords)
    unwanted = _split_terms(avoids)

    # Allow listings slightly above the ceiling: they may still be
    # negotiated down into range.
    ask_cap = max_price * 1.15
    # The price target does not depend on the item, so compute it once.
    target = (max_price + min_price) / 2 if min_price > 0 else max_price * 0.8
    target_scale = max(target, 1)  # guards against division by zero

    matches = []
    for item in pool:
        if item["category"] != category:
            continue
        ask = item["ask_price"]
        if ask > ask_cap or ask < min_price:
            continue
        haystack = f"{item['name']} {item['features']} {item['notes']}".lower()
        if any(term in haystack for term in unwanted):
            continue
        keyword_hits = sum(1 for term in wanted if term in haystack)
        price_distance = abs(ask - target) / target_scale
        matches.append({**item, "_score": keyword_hits * 2 - price_distance})

    # list.sort is stable, so equally scored items keep inventory order.
    matches.sort(key=lambda entry: entry["_score"], reverse=True)
    return matches[:max(int(top_k), 0)]
193
+
194
def format_inventory_context(
    items: List[Dict], reveal_floor: bool = False
) -> str:
    """Render inventory items as one human-readable line per item.

    Args:
        items: Inventory records providing ``id``, ``name``, ``condition``,
            ``ask_price`` and ``features`` (plus ``reservation_price`` when
            ``reveal_floor`` is set).
        reveal_floor: When true, append each item's reservation price as a
            ``Floor`` field.

    Returns:
        The lines joined with newlines, or a fixed "no match" message when
        ``items`` is empty.
    """
    if not items:
        return "No matching inventory found."

    def _describe(entry: Dict) -> str:
        # Prices use the "," format spec for thousands separators.
        text = "[{}] {} | {} | Ask: ${:,} | Features: {}".format(
            entry["id"],
            entry["name"],
            entry["condition"],
            entry["ask_price"],
            entry["features"],
        )
        if reveal_floor:
            text = text + " | Floor: ${:,}".format(entry["reservation_price"])
        return text

    return "\n".join(_describe(entry) for entry in items)
207
+
208
+ # ================================================================
209
+ # TEMPLATES
210
+ # ================================================================
211
+ TEMPLATES = {
212
+ "seller_open_firm": [
213
+ "I've had this {item} listed and I'm firm at ${p:,.0f}. "
214
+ "It's priced fairly for the condition.",
215
+ "The market supports ${p:,.0f} for a {item} like this. "
216
+ "I've done my research.",
217
+ "Asking ${p:,.0f} for the {item}. I'm not in a rush β€” "
218
+ "prefer not to negotiate far from that.",
219
+ ],
220
+ "seller_open_motivated": [
221
+ "I'm listing the {item} at ${p:,.0f} but open to reasonable offers. "
222
+ "I'd like to move this quickly.",
223
+ "Got this {item} up for ${p:,.0f}. "
224
+ "Motivated to sell β€” make me an offer.",
225
+ "Selling the {item} at ${p:,.0f}. "
226
+ "I have flexibility if you're serious about buying today.",
227
+ ],
228
+ "seller_counter_hold": [
229
+ "I appreciate the offer but I can't go below ${p:,.0f}. "
230
+ "That's really my floor.",
231
+ "I hear you, but ${p:,.0f} is already a stretch. "
232
+ "I have other interested buyers closer to asking.",
233
+ "That doesn't quite work. I could come to ${p:,.0f} "
234
+ "but that's genuinely as low as I go.",
235
+ ],
236
+ "seller_counter_concede": [
237
+ "Alright, I can meet you a bit closer β€” how does ${p:,.0f} sound?",
238
+ "I've thought about it and I can work with ${p:,.0f} "
239
+ "if we can close today.",
240
+ "Let me split the difference with you. ${p:,.0f} β€” fair?",
241
+ ],
242
+ "seller_stall": [
243
+ "Let me think on that overnight. "
244
+ "I want to make sure I'm not leaving too much on the table.",
245
+ "I've got another showing tomorrow. "
246
+ "Give me until then to decide if your number works.",
247
+ "I need to check with my partner before I commit to that price.",
248
+ ],
249
+ "seller_reject": [
250
+ "I can't do that price β€” it doesn't cover what I have into this.",
251
+ "That's too far from asking. I'd rather hold onto it.",
252
+ "I appreciate you trying but that number doesn't work for me at all.",
253
+ ],
254
+ "seller_return_after_walkaway": [
255
+ "Hey, I've been thinking. The other buyer fell through β€” "
256
+ "would you still do ${p:,.0f}?",
257
+ "Circling back β€” other deal didn't pan out. "
258
+ "If ${p:,.0f} is still on the table I'd like to make it work.",
259
+ "The showing yesterday didn't go anywhere. "
260
+ "I'm willing to revisit your ${p:,.0f}.",
261
+ ],
262
+ "seller_urgency": [
263
+ "Someone else is coming to look this weekend. "
264
+ "If you want it at ${p:,.0f} I need to know by tomorrow.",
265
+ "Just so you know I've got two other people interested. "
266
+ "First right of refusal at ${p:,.0f}.",
267
+ "My situation has changed and I need to close this week. "
268
+ "${p:,.0f} only if we finalize today.",
269
+ ],
270
+ "seller_accept": [
271
+ "You know what, ${p:,.0f} works. Let's do it.",
272
+ "Deal. ${p:,.0f} and it's yours.",
273
+ "Alright, I'll take ${p:,.0f}. When can you pick it up?",
274
+ ],
275
+ "seller_exit": [
276
+ "I don't think we're going to get there on price. "
277
+ "Good luck with your search.",
278
+ "We're too far apart. I'm going to wait for a better offer.",
279
+ "I appreciate the interest but this isn't going to work "
280
+ "at your number.",
281
+ ],
282
+ "seller_search": [
283
+ "Let me check if I have something that better fits "
284
+ "what you're describing.",
285
+ "Hold on β€” I think I may have another option in my inventory "
286
+ "that suits your needs.",
287
+ "I want to make sure I'm showing you the best match. "
288
+ "Let me pull some alternatives.",
289
+ ],
290
+ "buyer_open_aggressive": [
291
+ "I'll offer ${p:,.0f} and that's already above what I was "
292
+ "planning to spend.",
293
+ "I can do ${p:,.0f} cash today. "
294
+ "I know that's low but I need to stay in my budget.",
295
+ "First and best offer: ${p:,.0f}. "
296
+ "I've seen similar {item}s go for less.",
297
+ ],
298
+ "buyer_open_strategic": [
299
+ "I've done some research on {item} values in this market. "
300
+ "Based on comps I think ${p:,.0f} is fair.",
301
+ "I'm genuinely interested. I'd like to start at ${p:,.0f} β€” "
302
+ "I think there's a deal here.",
303
+ "Serious buyer, ready to close fast. "
304
+ "With that in mind, ${p:,.0f}.",
305
+ ],
306
+ "buyer_counter_nibble": [
307
+ "Getting closer. Can you do ${p:,.0f}? "
308
+ "That's where I need to be to feel good about the deal.",
309
+ "I'd say yes at ${p:,.0f}. "
310
+ "Throw in the extras and I'll pull the trigger right now.",
311
+ "If you can get to ${p:,.0f} I won't waste any more of "
312
+ "your time β€” deal done.",
313
+ ],
314
+ "buyer_counter_hold": [
315
+ "I've thought about it and I'm still at ${p:,.0f}. "
316
+ "That's genuinely what this is worth to me.",
317
+ "My budget hasn't changed. ${p:,.0f} is the number.",
318
+ "I hear you on the other buyers but ${p:,.0f} is my ceiling.",
319
+ ],
320
+ "buyer_stall": [
321
+ "I need to sleep on it. "
322
+ "I'm also looking at a couple other options this week.",
323
+ "Let me talk to my partner tonight and get back to you tomorrow.",
324
+ "I'm not going to rush into this. Give me a day or two.",
325
+ ],
326
+ "buyer_walkaway": [
327
+ "I don't think we're going to get there. "
328
+ "Thanks for your time β€” good luck with the sale.",
329
+ "I'm going to pass. The price just doesn't work for what I need.",
330
+ "Going to look at other options. "
331
+ "If your price changes, feel free to reach out.",
332
+ ],
333
+ "buyer_return_after_walkaway": [
334
+ "Hey, been thinking about the {item} since we talked. "
335
+ "Is ${p:,.0f} still the best you can do?",
336
+ "Still have the {item} available? "
337
+ "I might stretch to ${p:,.0f} if we can close quickly.",
338
+ "Came back because I couldn't find anything comparable. "
339
+ "Would you take ${p:,.0f}?",
340
+ ],
341
+ "buyer_accept": [
342
+ "Alright, you've got a deal at ${p:,.0f}.",
343
+ "Fine, ${p:,.0f}. Let's stop going back and forth β€” I'll take it.",
344
+ "Done. ${p:,.0f}. When can I come get it?",
345
+ ],
346
+ "buyer_reject": [
347
+ "That's still too high. I can't justify that price.",
348
+ "No, that doesn't work. "
349
+ "I'd need to see a significant move to reconsider.",
350
+ "I'm out at that number. "
351
+ "Not what the market is bearing right now.",
352
+ ],
353
+ "buyer_deadline": [
354
+ "I need to make a decision by end of day β€” "
355
+ "can you give me your absolute best price?",
356
+ "My budget approval expires Friday. "
357
+ "If we agree on ${p:,.0f} right now I can move immediately.",
358
+ "I have to make a call today. "
359
+ "Meet me at ${p:,.0f} and we close this out.",
360
+ ],
361
+ "buyer_search": [
362
+ "Do you have anything else in this category that might "
363
+ "work better for my needs?",
364
+ "I'm not sure this is the right fit. "
365
+ "Do you have other options I should look at?",
366
+ "Before I decide, do you have alternatives β€” "
367
+ "maybe different condition or price point?",
368
+ ],
369
+ }
370
+
371
def _t(key: str, item: str = "", p: float = 0,
       avoid: str = "", must: str = "") -> str:
    """Return one randomly chosen template for ``key`` with its fields filled.

    Args:
        key: Name of the template group in ``TEMPLATES``.
        item: Value substituted for ``{item}``.
        p: Price substituted for ``{p}`` (templates apply their own format spec).
        avoid: Value substituted for ``{avoid}``.
        must: Value substituted for ``{must}``.

    Raises:
        KeyError: If ``key`` is not present in ``TEMPLATES``.
    """
    template = random.choice(TEMPLATES[key])
    fields = {"item": item, "p": p, "avoid": avoid, "must": must}
    # str.format ignores keyword fields the chosen template does not use.
    return template.format(**fields)
376
+
377
+ # ================================================================
378
+ # STRATEGY PROFILES
379
+ # ================================================================
380
# Per-persona knobs read by the synthetic session generator for the buyer side.
#   open_discount   -> (low, high) fraction of list price for the opening bid
#   concession_rate -> fraction of list price conceded per move (before jitter)
#   walkaway_prob   -> per-round chance of a one-time walk-away
#   return_prob     -> chance the buyer comes back after walking away
#   patience        -> round index after which deadline messages may appear
#   search_prob     -> per-round chance of asking for alternatives
BUYER_STRATEGY = {
    "aggressive": {
        "open_discount": (0.55, 0.68),
        "concession_rate": 0.015,
        "walkaway_prob": 0.35,
        "return_prob": 0.50,
        "patience": 3,
        "search_prob": 0.10,
    },
    "patient": {
        "open_discount": (0.72, 0.82),
        "concession_rate": 0.025,
        "walkaway_prob": 0.15,
        "return_prob": 0.70,
        "patience": 8,
        "search_prob": 0.20,
    },
    "skeptical": {
        "open_discount": (0.65, 0.75),
        "concession_rate": 0.018,
        "walkaway_prob": 0.28,
        "return_prob": 0.45,
        "patience": 5,
        "search_prob": 0.30,
    },
    "impulsive": {
        "open_discount": (0.78, 0.88),
        "concession_rate": 0.040,
        "walkaway_prob": 0.10,
        "return_prob": 0.30,
        "patience": 2,
        "search_prob": 0.05,
    },
    "strategic": {
        "open_discount": (0.62, 0.72),
        "concession_rate": 0.022,
        "walkaway_prob": 0.30,
        "return_prob": 0.65,
        "patience": 7,
        "search_prob": 0.25,
    },
}
407
+
408
# Per-persona knobs read by the synthetic session generator for the seller side.
#   min_discount    -> NOTE(review): not read by the generator shown in this
#                      file section; presumably a price-floor fraction - confirm
#   concession_rate -> fraction of list price conceded per move (before jitter)
#   urgency_prob    -> per-round chance of an urgency/stall message
#   return_prob     -> chance the seller re-engages after the buyer exits
#   search_prob     -> per-round chance of proposing an alternative item
SELLER_STRATEGY = {
    "firm": {
        "min_discount": 0.93,
        "concession_rate": 0.008,
        "urgency_prob": 0.15,
        "return_prob": 0.30,
        "search_prob": 0.15,
    },
    "motivated": {
        "min_discount": 0.82,
        "concession_rate": 0.030,
        "urgency_prob": 0.40,
        "return_prob": 0.60,
        "search_prob": 0.35,
    },
    "anchoring": {
        "min_discount": 0.90,
        "concession_rate": 0.010,
        "urgency_prob": 0.25,
        "return_prob": 0.40,
        "search_prob": 0.20,
    },
    "collaborative": {
        "min_discount": 0.86,
        "concession_rate": 0.022,
        "urgency_prob": 0.20,
        "return_prob": 0.55,
        "search_prob": 0.40,
    },
    "desperate": {
        "min_discount": 0.75,
        "concession_rate": 0.045,
        "urgency_prob": 0.60,
        "return_prob": 0.75,
        "search_prob": 0.30,
    },
}
430
+
431
+ # ================================================================
432
+ # DATA GENERATOR
433
+ # ================================================================
434
def generate_sessions(n_sessions: int) -> List[Dict]:
    """Simulate ``n_sessions`` buyer/seller negotiations and return flat rows.

    Each session draws a category, list price, buyer/seller persona and the
    hidden bounds (buyer budget, buyer estimate, seller reservation), then
    alternates seller and buyer moves until a deal, an exit, or the round
    limit is reached.

    Args:
        n_sessions: Number of sessions to simulate (coerced with ``int``).

    Returns:
        A list of per-turn dicts. Every row carries the session id, turn
        number, party (0 = seller, 1 = buyer), prices, message type, message
        text, both personas and the three hidden bounds.
    """
    # Accept template to use for each party index (0 = seller, 1 = buyer).
    accept_tmpl = ("seller_accept", "buyer_accept")
    all_rows = []

    for _ in range(int(n_sessions)):
        cat = random.choice(CATEGORIES)
        item = cat.replace("_", " ").title()
        lp = round(random.uniform(500, 25000), -1)
        # 12 hex chars (48 bits): rows are grouped by session id downstream,
        # and the previous 6-char id collided in large runs.
        sid = f"SYN-{uuid.uuid4().hex[:12].upper()}"
        b_persona = random.choice(BUYER_PERSONAS)
        s_persona = random.choice(SELLER_PERSONAS)
        bs = BUYER_STRATEGY[b_persona]
        ss = SELLER_STRATEGY[s_persona]
        turn = 0
        rows = []
        walked = False

        b_budget = lp * random.uniform(0.85, 1.05)
        b_estimate = lp * random.uniform(0.65, 0.80)
        s_reserve = lp * random.uniform(0.72, 0.88)

        def add(party, price, mtype, msg):
            """Append one turn to the current session."""
            nonlocal turn
            turn += 1
            rows.append({
                "session_id": sid,
                "turn_number": turn,
                "party": party,
                "category": cat,
                "item": item,
                "list_price": lp,
                "offer_price": round(price, 2),
                "msg_type": mtype,
                "message": msg,
                "buyer_persona": b_persona,
                "seller_persona": s_persona,
                "buyer_budget": b_budget,
                "buyer_estimate": b_estimate,
                "seller_reservation": s_reserve,
            })

        sp = lp
        bp = round(lp * random.uniform(*bs["open_discount"]), -1)

        s_tmpl = ("seller_open_motivated"
                  if s_persona in ["motivated", "desperate"]
                  else "seller_open_firm")
        b_tmpl = ("buyer_open_aggressive"
                  if b_persona == "aggressive"
                  else "buyer_open_strategic")

        add(0, sp, "offer", _t(s_tmpl, item=item, p=sp))
        add(1, bp, "counter", _t(b_tmpl, item=item, p=bp))

        max_turns = random.randint(8, 24)
        prev_sp = sp
        prev_bp = bp
        stall_streak = 0

        for rnd in range(max_turns):
            gap = sp - bp
            gap_pct = gap / lp if lp > 0 else 0

            # Natural close: once the gap is under 3% either side may accept
            # at the midpoint.
            # NOTE(review): the source view lost its indentation; the break is
            # assumed to belong to the 75% accept branch - confirm.
            if gap_pct < 0.03:
                fp = round((sp + bp) / 2, -1)
                if random.random() < 0.75:
                    closer = random.choice([0, 1])
                    # The message template follows the speaking party.
                    add(closer, fp, "accept", _t(accept_tmpl[closer], p=fp))
                    break

            # ── Seller turn ───────────────────────────────────
            if random.random() < ss["search_prob"] and rnd > 1:
                add(0, sp, "search", _t("seller_search"))
                match_p = round(sp * random.uniform(0.88, 0.98), -1)
                add(0, match_p, "counter",
                    f"I found something that might work better — "
                    f"similar {item} at ${match_p:,.0f} with better "
                    f"specs for your needs.")
                sp = match_p
                stall_streak = 0
            elif random.random() < ss["urgency_prob"] and rnd > 1:
                add(0, sp, "stall", _t("seller_urgency", item=item, p=sp))
                stall_streak += 1
            elif gap_pct > 0.30:
                add(0, sp, "reject", _t("seller_reject"))
            elif prev_sp == sp and stall_streak < 2:
                add(0, sp, "stall", _t("seller_stall"))
                stall_streak += 1
            else:
                concede_s = (ss["concession_rate"] * lp
                             * random.uniform(0.5, 1.5))
                # Never cross the buyer's side of the gap nor the reservation.
                sp = max(max(bp + gap * 0.15, sp - concede_s), s_reserve)
                sp = round(sp, -1)
                tmpl = ("seller_counter_concede"
                        if concede_s > lp * 0.02
                        else "seller_counter_hold")
                add(0, sp, "counter", _t(tmpl, p=sp))
                stall_streak = 0

            prev_sp = sp
            gap = sp - bp

            # ── Buyer turn ────────────────────────────────────
            concede_b = (bs["concession_rate"] * lp
                         * random.uniform(0.5, 1.5))

            if (random.random() < bs["search_prob"]
                    and gap_pct > 0.12 and rnd > 1):
                add(1, bp, "search", _t("buyer_search"))
                new_bp = round(bp * random.uniform(1.01, 1.06), -1)
                add(1, new_bp, "counter",
                    f"I looked at your alternatives — I could do "
                    f"${new_bp:,.0f} for the right {item} with the "
                    f"features I need.")
                bp = new_bp

            elif (not walked
                    and random.random() < bs["walkaway_prob"]
                    and rnd > 2):
                # A buyer walks away at most once per session.
                walked = True
                add(1, bp, "exit", _t("buyer_walkaway"))
                if random.random() < bs["return_prob"]:
                    rp = round(bp * 1.04, -1)
                    add(1, rp, "counter",
                        _t("buyer_return_after_walkaway",
                           item=item, p=rp))
                    bp = rp
                else:
                    break

            elif rnd > bs["patience"] and random.random() < 0.30:
                bp = min(sp - gap * 0.1, bp + concede_b)
                bp = min(bp, b_budget)
                bp = round(bp, -1)
                add(1, bp, "counter", _t("buyer_deadline", p=bp))

            elif gap_pct < 0.08 and random.random() < 0.40:
                add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))

            elif random.random() < 0.15:
                add(1, bp, "stall", _t("buyer_stall"))

            elif prev_bp == bp and random.random() < 0.35:
                add(1, bp, "counter", _t("buyer_counter_hold", p=bp))

            else:
                bp = min(bp + concede_b, b_budget)
                bp = min(sp - gap * 0.15, bp)
                bp = round(bp, -1)
                add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))

            prev_bp = bp

            # Gap still enormous: buyer exits, seller may chase with a cut.
            if gap / lp > 0.45:
                add(1, bp, "exit", _t("buyer_reject"))
                if random.random() < ss["return_prob"]:
                    new_sp = round(sp * 0.94, -1)
                    add(0, new_sp, "counter",
                        _t("seller_return_after_walkaway", p=new_sp))
                    sp = new_sp
                else:
                    break
        else:
            # Round limit reached without a break: close if near, else exit.
            if (sp - bp) / lp < 0.08:
                fp = round((sp + bp) / 2, -1)
                closer = random.choice([0, 1])
                add(closer, fp, "accept", _t(accept_tmpl[closer], p=fp))
            else:
                add(1, bp, "exit", _t("buyer_walkaway"))

        all_rows.extend(rows)

    return all_rows
609
 
610
+ # ================================================================
611
# FEATURE EXTRACTION — all list guards in place
612
+ # ================================================================
613
def extract_features(turns, idx, lp,
                     b_budget=0, b_estimate=0, s_reserve=0):
    """Summarise the turns before position ``idx`` as 10 momentum features.

    Args:
        turns: Ordered turn dicts with ``party`` (0 = seller, 1 = buyer),
            ``offer_price`` and ``msg_type`` keys.
        idx: Only ``turns[:idx]`` is inspected.
        lp: List price used to normalise price differences.
        b_budget: Buyer budget; ``0`` disables the budget feature.
        b_estimate: Buyer's value estimate.
        s_reserve: Seller reservation price; ``0`` disables the floor feature.

    Returns:
        A list of 10 floats: net velocity (seller drift minus buyer drift),
        clamped gap ratio, seller concessions, buyer concessions, stall
        share, search share, budget distance, floor distance, normalised
        turn index and a constant ``0.0`` slot. All zeros when there is no
        history or ``lp`` is not positive.
    """
    n_features = 10
    hist = turns[:idx]
    # Every price feature divides by lp, so a non-positive list price cannot
    # be normalised; report "no signal" instead of raising ZeroDivisionError.
    if not hist or lp <= 0:
        return [0.0] * n_features

    sp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 0]
    bp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 1]

    def drift(prices):
        """Normalised first-to-last price movement; 0.0 with < 2 prices."""
        if len(prices) > 1:
            return (prices[-1] - prices[0]) / lp
        return 0.0

    s_vel = drift(sp_prices)
    b_vel = drift(bp_prices)

    # Without a price from both sides the gap defaults to the full list price.
    if sp_prices and bp_prices:
        gap_r = (sp_prices[-1] - bp_prices[-1]) / lp
    else:
        gap_r = 1.0

    # Concessions count only moves toward the other side
    # (seller price going down, buyer price going up).
    s_con = sum(
        max(0, prev - cur) for prev, cur in zip(sp_prices, sp_prices[1:])
    ) / lp
    b_con = sum(
        max(0, cur - prev) for prev, cur in zip(bp_prices, bp_prices[1:])
    ) / lp

    n_hist = len(hist)
    stalls = sum(1 for r in hist if r["msg_type"] == "stall") / n_hist
    searches = sum(1 for r in hist if r["msg_type"] == "search") / n_hist

    # Bound-relative features; the denominators are floored at 1.
    if b_budget > 0 and bp_prices:
        budget_dist = min(
            (bp_prices[-1] - b_estimate) / max(b_budget - b_estimate, 1), 2.0
        )
    else:
        budget_dist = 0.0

    if s_reserve > 0 and sp_prices:
        floor_dist = min(
            (sp_prices[-1] - s_reserve) / max(lp - s_reserve, 1), 1.5
        )
    else:
        floor_dist = 0.5

    turns_norm = min(idx / 25.0, 1.0)

    return [
        float(s_vel - b_vel),
        float(min(max(gap_r, 0.0), 2.0)),
        float(min(s_con, 2.0)),
        float(min(b_con, 2.0)),
        float(stalls),
        float(searches),
        float(budget_dist),
        float(floor_dist),
        float(turns_norm),
        0.0,  # constant slot in the 10-wide vector
    ]
668
+
669
+ # ================================================================
670
# DATASET BUILDER — selective pin_memory (small tensors only)
671
+ # ================================================================
672
def build_pinned_dataset(rows: List[Dict]) -> TensorDataset:
    """Turn flat session rows into a ``TensorDataset`` of next-turn examples.

    Rows are grouped by ``session_id`` and ordered by ``turn_number``. Every
    turn after the first becomes one example whose text context is the (up
    to) three preceding messages and whose targets are that turn's message
    type and normalised price.

    Args:
        rows: Per-turn dicts as produced by ``generate_sessions``.

    Returns:
        A ``TensorDataset`` with tensors in this order: input ids, attention
        mask, party, category, previous-offer ratio, turn ratio, message-type
        target, price target, buyer persona, seller persona, momentum
        features.
    """
    sessions = {}
    for r in rows:
        sessions.setdefault(r["session_id"], []).append(r)

    (texts, party_l, cat_l, ofn_l, tn_l,
     msg_l, pt_l, bp_l, sp_l, mom_l) = ([] for _ in range(10))

    for turns in sessions.values():
        turns = sorted(turns, key=lambda x: int(x["turn_number"]))
        lp = float(turns[0]["list_price"])
        if lp <= 0:
            continue
        b_bud = float(turns[0].get("buyer_budget", lp))
        b_est = float(turns[0].get("buyer_estimate", lp * 0.75))
        s_res = float(turns[0].get("seller_reservation", lp * 0.80))

        for i in range(1, len(turns)):
            tgt = turns[i]
            prev = turns[i - 1]
            recent = turns[max(0, i - 3):i]
            text = " [SEP] ".join(
                f"{'S' if int(t['party']) == 0 else 'B'}: {t['message']}"
                for t in recent
            )
            mom = extract_features(turns, i, lp, b_bud, b_est, s_res)

            texts.append(text)
            party_l.append(int(tgt["party"]))
            cat_l.append(CAT2IDX.get(tgt["category"], 0))
            # Input price feature = the last offer already on the table.
            # Feeding the target turn's own price here would hand the model
            # its regression label (pt) as an input; inference feeds the
            # latest observed price, so training must do the same.
            ofn_l.append(min(float(prev["offer_price"]) / lp, 3.0))
            tn_l.append(min(int(tgt["turn_number"]) / 25.0, 1.0))
            msg_l.append(MSG2IDX.get(tgt["msg_type"], 1))
            pt_l.append(min(float(tgt["offer_price"]) / lp, 3.0))
            bp_l.append(BPERSONA2IDX.get(
                tgt.get("buyer_persona", "patient"), 1))
            sp_l.append(SPERSONA2IDX.get(
                tgt.get("seller_persona", "firm"), 0))
            mom_l.append(mom)

    # Drop the local references to the raw rows before tokenising.
    del sessions, rows
    gc.collect()

    n = len(texts)
    input_ids = torch.empty((n, MAX_LEN), dtype=torch.long)
    attn_mask = torch.empty((n, MAX_LEN), dtype=torch.long)

    # Tokenise in fixed-size chunks and write into the preallocated tensors.
    chunk_size = 20000
    for start in range(0, n, chunk_size):
        stop = start + chunk_size
        enc = tokenizer(
            texts[start:stop], max_length=MAX_LEN,
            padding="max_length", truncation=True,
            return_tensors="pt"
        )
        input_ids[start:stop] = enc["input_ids"]
        attn_mask[start:stop] = enc["attention_mask"]

    del texts
    gc.collect()

    # Insertion order here defines the positional order of the dataset.
    tensors = dict(
        ids=input_ids,
        mask=attn_mask,
        pty=torch.tensor(party_l, dtype=torch.long),
        cat=torch.tensor(cat_l, dtype=torch.long),
        ofn=torch.tensor(ofn_l, dtype=torch.float),
        tn=torch.tensor(tn_l, dtype=torch.float),
        mt=torch.tensor(msg_l, dtype=torch.long),
        pt=torch.tensor(pt_l, dtype=torch.float),
        bp=torch.tensor(bp_l, dtype=torch.long),
        sp=torch.tensor(sp_l, dtype=torch.long),
        mom=torch.tensor(mom_l, dtype=torch.float),
    )
    del party_l, cat_l, ofn_l, tn_l, msg_l, pt_l, bp_l, sp_l, mom_l
    gc.collect()

    # ── Selective pin_memory ──────────────────────────────────
    # Only the small per-example tensors are pinned. ids/mask are left
    # pageable: the original author's note reports that pinning those two
    # large tensors exhausted GPU memory before training started.
    if DEVICE.type == "cuda":
        SMALL_KEYS = {"pty", "cat", "ofn", "tn", "mt", "pt", "bp", "sp", "mom"}
        tensors = {
            k: (v.pin_memory() if k in SMALL_KEYS else v)
            for k, v in tensors.items()
        }

    return TensorDataset(*tensors.values())
761
+
762
+ # ================================================================
763
+ # MODEL
764
+ # ================================================================
765
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch-first sequence.

    Args:
        d: Embedding width. Odd widths are supported.
        max_len: Longest sequence length for which positions are precomputed.
    """

    def __init__(self, d: int, max_len: int = 512):
        super().__init__()
        self.drop = nn.Dropout(0.1)
        pe = torch.zeros(max_len, d)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d, 2).float() * (-math.log(10000.0) / d)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d there is one fewer odd column than even column, so the
        # frequency vector is trimmed to the number of cosine slots.
        pe[:, 1::2] = torch.cos(pos * div[: d // 2])
        # Buffer: follows .to(device) and is saved, but is not a parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` positions, with dropout."""
        return self.drop(x + self.pe[:, :x.size(1)])
780
+
781
+
782
class MomentumEncoder(nn.Module):
    """Two-layer MLP that projects the momentum feature vector.

    Args:
        in_dim: Width of the incoming feature vector.
        out_dim: Width of the produced embedding.
    """

    def __init__(self, in_dim: int = 10, out_dim: int = 48):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        encoded = self.net(x)
        return encoded
790
+
791
 
792
class NegotiationTransformer(nn.Module):
    """Text encoder fused with negotiation context, with two output heads.

    The token sequence is encoded by a Transformer encoder; the vector at
    sequence position 0 is concatenated with party, category and persona
    embeddings, the encoded momentum features and two scalar inputs, then
    projected back to ``D_MODEL``. One head scores message types, the other
    predicts a non-negative price value.
    """

    def __init__(self):
        super().__init__()
        party_dim, cat_dim, persona_dim, mom_dim = 32, 64, 32, 48
        n_scalars = 2  # ofn and tn are appended as raw scalars

        self.emb = nn.Embedding(30522, D_MODEL, padding_idx=0)
        self.pos = PositionalEncoding(D_MODEL)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=FFN_DIM,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=N_LAYERS)

        self.p_emb = nn.Embedding(2, party_dim)
        self.c_emb = nn.Embedding(len(CATEGORIES), cat_dim)
        self.bp_emb = nn.Embedding(len(BUYER_PERSONAS), persona_dim)
        self.sp_emb = nn.Embedding(len(SELLER_PERSONAS), persona_dim)
        self.mom_enc = MomentumEncoder(10, mom_dim)

        total_ctx = (D_MODEL + party_dim + cat_dim
                     + persona_dim + persona_dim + mom_dim + n_scalars)
        self.fusion = nn.Sequential(
            nn.Linear(total_ctx, D_MODEL),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
        self.px_head = nn.Sequential(
            nn.Linear(D_MODEL, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Softplus(),
        )

    def forward(self, ids, mask, party, cat, ofn, tn, bp, sp, mom):
        """Return ``(message-type logits, price prediction)`` for a batch.

        ``mask`` marks real tokens with non-zero values; positions equal to
        0 are treated as padding by the encoder.
        """
        padding = mask == 0
        tokens = self.pos(self.emb(ids))
        encoded = self.encoder(tokens, src_key_padding_mask=padding)
        pooled = encoded[:, 0]  # representation at the first position

        scalars = torch.stack([ofn, tn], dim=1)
        pieces = [
            pooled,
            self.p_emb(party),
            self.c_emb(cat),
            self.bp_emb(bp),
            self.sp_emb(sp),
            self.mom_enc(mom),
            scalars,
        ]
        fused = self.fusion(torch.cat(pieces, dim=1))

        msg_logits = self.msg_head(fused)
        price = self.px_head(fused).squeeze(1)
        return msg_logits, price
832
 
833
+
834
class AsymmetricNegotiationLoss(nn.Module):
    """Party-dependent weighted cross-entropy plus a price MSE term.

    Seller rows (party 0) and buyer rows (party 1) use different per-class
    weights for the message-type loss. Rows whose party is neither value
    contribute zero to that term.
    """

    def __init__(self):
        super().__init__()
        # [offer, counter, accept, reject, exit, stall, search]
        self.seller_w = torch.tensor([1.0, 1.0, 1.5, 1.2, 1.3, 0.8, 1.1])
        self.buyer_w = torch.tensor([1.0, 1.0, 1.3, 1.0, 1.2, 0.9, 1.2])

    def forward(self, mt_logits, mt_targets, px_pred, px_targets, party):
        """Return mean weighted message loss + 0.5 * price MSE."""
        device = mt_logits.device
        per_row = torch.zeros(mt_logits.size(0), device=device)

        # The weights are plain tensors, so they are moved to the logits'
        # device on every call.
        for side, class_w in ((0, self.seller_w), (1, self.buyer_w)):
            class_w = class_w.to(device)
            selected = party == side
            if not selected.any():
                continue
            per_row[selected] = F.cross_entropy(
                mt_logits[selected], mt_targets[selected],
                weight=class_w, reduction="none"
            )

        price_term = F.mse_loss(px_pred, px_targets)
        return per_row.mean() + 0.5 * price_term
859
+
860
+ # ================================================================
861
+ # PLOT
862
+ # ================================================================
863
def plot_curve(losses):
    """Build a small training-loss figure.

    Args:
        losses: Sequence of per-epoch average losses; may be empty.

    Returns:
        A matplotlib ``Figure``. With no data it shows a "No data yet"
        placeholder instead of a curve.
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    if losses:
        ax.plot(range(1, len(losses) + 1), losses, "b-o", markersize=4)
        ax.set_title("Training Loss")
    else:
        ax.text(0.5, 0.5, "No data yet",
                ha="center", va="center", alpha=0.5)
    ax.grid(alpha=0.3)
    # Lay out via the figure itself so this works once it is detached.
    fig.tight_layout()
    # This function is called repeatedly during training; without closing,
    # every figure stays registered with pyplot and memory grows. The
    # returned Figure object remains usable for rendering after close().
    plt.close(fig)
    return fig
874
+
875
+ # ================================================================
876
+ # TRAINING
877
+ # ================================================================
878
def run_training(n_sessions, epochs, batch_size, lr):
    """Generate data, train the model and stream progress to the UI.

    This is a generator. Each yielded value is a 4-tuple
    ``(status text, rolling log text, loss figure, model-readiness label)``.
    On success the trained model is stored in ``GLOBAL_MODEL``. Any
    exception is reported through a final error tuple, not raised.

    Args:
        n_sessions: Number of synthetic sessions to generate.
        epochs: Number of training epochs.
        batch_size: Mini-batch size (incomplete final batches are dropped).
        lr: AdamW learning rate.
    """
    global GLOBAL_MODEL
    logs = []
    not_ready = "❌ Needs Training"

    def log(msg):
        """Record a timestamped line and return the last 20 lines joined."""
        ts = time.strftime("%H:%M:%S")
        line = f"[{ts}] {msg}"
        logs.append(line)
        if len(logs) > 20:
            logs.pop(0)
        print(line)
        return "\n".join(logs)

    try:
        batch_size = int(batch_size)
        n_sessions = int(n_sessions)
        epochs = int(epochs)

        log_txt = log(f"Generating {n_sessions:,} sessions...")
        yield "🟡 Generating...", log_txt, plot_curve([]), not_ready

        rows = generate_sessions(n_sessions)
        log_txt = log(f"Generated {len(rows):,} rows. Building dataset...")
        yield "🟡 Tokenizing...", log_txt, plot_curve([]), not_ready

        dataset = build_pinned_dataset(rows)
        loader = DataLoader(
            dataset, batch_size=batch_size,
            shuffle=True, num_workers=0,
            pin_memory=False, drop_last=True
        )
        total_batches = len(loader)
        # drop_last=True yields zero batches when the dataset is smaller
        # than one batch; fail with a clear message instead of a later
        # division by zero.
        if total_batches == 0:
            raise ValueError(
                f"Dataset has {len(dataset):,} samples, fewer than one "
                f"batch of {batch_size}; generate more sessions or lower "
                f"the batch size."
            )

        log_txt = log(f"Dataset: {len(dataset):,} samples | "
                      f"{total_batches} batches | bs={batch_size}")
        yield "🟡 Building model...", log_txt, plot_curve([]), not_ready

        model = NegotiationTransformer().to(DEVICE)
        crit = AsymmetricNegotiationLoss()

        use_cuda = DEVICE.type == "cuda"
        if hasattr(torch, "compile") and use_cuda:
            try:
                model = torch.compile(model, backend="cudagraphs")
                log_txt = log("torch.compile (cudagraphs) applied")
            except Exception as ce:
                # Compilation is optional; continue with the eager model.
                log_txt = log(f"compile skipped: {ce}")

        opt = AdamW(model.parameters(), lr=float(lr), weight_decay=1e-2)
        sch = CosineAnnealingLR(opt, T_max=epochs)
        # Mixed precision is only meaningful on CUDA. With enabled=False
        # the scaler and autocast become pass-throughs, so the same loop
        # runs on CPU without warnings.
        scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
        losses = []

        log_txt = log("🚀 Training started")
        yield "🟢 Training...", log_txt, plot_curve([]), not_ready

        for ep in range(epochs):
            model.train()
            ep_loss = 0.0
            t0 = time.time()

            for i, batch in enumerate(loader):
                (b_ids, b_mask, b_pty, b_cat, b_ofn,
                 b_tn, b_mt, b_pt, b_bp, b_sp, b_mom) = [
                    t.to(DEVICE, non_blocking=True) for t in batch
                ]

                # Push a progress update to the UI every 100 batches.
                if i % 100 == 0:
                    el = time.time() - t0
                    ms_b = (el / max(i, 1)) * 1000
                    status = (f"🟢 Epoch {ep+1}/{epochs} | "
                              f"Batch {i}/{total_batches} | "
                              f"{ms_b:.0f}ms/batch")
                    log_txt = log(status)
                    yield (status, log_txt,
                           plot_curve(losses), not_ready)

                opt.zero_grad(set_to_none=True)

                with torch.cuda.amp.autocast(enabled=use_cuda):
                    mt_logits, px_pred = model(
                        b_ids, b_mask, b_pty, b_cat,
                        b_ofn, b_tn, b_bp, b_sp, b_mom
                    )
                    loss = crit(mt_logits, b_mt, px_pred, b_pt, b_pty)

                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the
                # true gradient norm.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                ep_loss += loss.item()

            sch.step()
            avg = ep_loss / total_batches
            et = time.time() - t0
            losses.append(avg)
            log_txt = log(
                f"Epoch {ep+1}/{epochs} done — "
                f"loss: {avg:.4f} | {et:.1f}s | "
                f"{et/total_batches*1000:.0f}ms/batch"
            )
            yield (f"🟢 Epoch {ep+1} done", log_txt,
                   plot_curve(losses), not_ready)

        model.eval()
        GLOBAL_MODEL = model
        log_txt = log("✅ Training complete.")
        yield "🔵 Complete", log_txt, plot_curve(losses), "✅ Ready"

    except Exception as e:
        # UI boundary: surface the failure in the log panel.
        import traceback
        log_txt = log(f"ERROR: {e}\n{traceback.format_exc()}")
        yield "🔴 ERROR", log_txt, plot_curve([]), not_ready
988
+
989
+ # ================================================================
990
+ # INFERENCE ENGINE
991
+ # ================================================================
992
def _build_message(msg_type, price, item,
                   is_buyer, persona, inv_context=""):
    """Render the speaker's message text for a predicted move.

    Args:
        msg_type: Predicted message type (e.g. ``"offer"``, ``"counter"``).
        price: Price inserted into templates that mention one.
        item: Item name inserted into opening-offer templates.
        is_buyer: True when the speaker is the buyer.
        persona: Speaker persona; selects alternative offer/counter wording.
        inv_context: Accepted for call compatibility and not used here.
            Search messages return early with a generic template and the
            caller appends inventory results itself.

    Returns:
        The rendered template, or ``"<msg_type> @ $<price>"`` for a message
        type that has no template.
    """
    p = price
    if msg_type == "search":
        return _t("buyer_search") if is_buyer else _t("seller_search")

    if is_buyer:
        keys = {
            "offer": "buyer_open_strategic",
            "counter": "buyer_counter_nibble",
            "accept": "buyer_accept",
            "reject": "buyer_reject",
            "exit": "buyer_walkaway",
            "stall": "buyer_stall",
        }
        if persona == "aggressive":
            keys["offer"] = "buyer_open_aggressive"
            keys["counter"] = "buyer_counter_hold"
    else:
        keys = {
            "offer": "seller_open_firm",
            "counter": "seller_counter_hold",
            "accept": "seller_accept",
            "reject": "seller_reject",
            "exit": "seller_exit",
            "stall": "seller_stall",
        }
        if persona in ["motivated", "desperate"]:
            keys["offer"] = "seller_open_motivated"
            keys["counter"] = "seller_counter_concede"

    key = keys.get(msg_type)
    if key is None:
        return f"{msg_type} @ ${p:,.2f}"

    # Render only the selected template; previously all six were formatted
    # (each drawing a random choice) on every call. The fields passed per
    # message type are unchanged: offers get item and price, counters and
    # accepts get the price, the rest get nothing.
    if msg_type == "offer":
        return _t(key, item=item, p=p)
    if msg_type in ("counter", "accept"):
        return _t(key, p=p)
    return _t(key)
1028
+
1029
 
1030
def run_inference_turn(
    session_state,
    category, item,
    list_price, user_price, user_message,
    user_party, user_persona, ai_persona,
    buyer_budget, buyer_estimate,
    buyer_avoids, buyer_must_have,
    seller_reservation, seller_urgency,
):
    """Record the user's turn, run the model and return the AI's reply.

    Args:
        session_state: Dict carrying the running session (empty or ``None``
            for a new one).
        category, item: Category key and display name of the item.
        list_price: Listing price; must be positive.
        user_price, user_message: The user's offer and text for this turn.
        user_party: ``"Buyer"`` makes the AI the seller; anything else
            makes the AI the buyer.
        user_persona, ai_persona: Persona names for each side.
        buyer_budget, buyer_estimate, seller_reservation: Hidden bounds;
            non-positive values fall back to defaults derived from the
            list price.
        buyer_avoids, buyer_must_have: Filters forwarded to inventory search.
        seller_urgency: Accepted for UI wiring; not read by this function.

    Returns:
        ``(session_state, chat history, summary line, message type,
        formatted AI price, class-probability string, inventory text)``.
    """
    session_state = session_state or {}

    if GLOBAL_MODEL is None:
        return (session_state,
                session_state.get("history_ui", []),
                "Model not trained.", "", "", "", "")

    lp = float(list_price)
    # Every ratio below divides by the list price.
    if lp <= 0:
        return (session_state,
                session_state.get("history_ui", []),
                "List price must be greater than zero.", "", "", "", "")

    is_user_buyer = (user_party == "Buyer")
    ai_party_int = 0 if is_user_buyer else 1

    # ── Initialise session ────────────────────────────────────
    if not session_state.get("started"):
        init_bp = (float(buyer_estimate)
                   if float(buyer_estimate) > 0
                   else round(lp * 0.75, -1))
        session_state = {
            "started": True,
            "turn": 0,
            "sp": lp,
            "bp": init_bp,
            "history": [],
            "history_ui": [],
            "status": "active",
            "inv_context": "",
        }

    if session_state["status"] != "active":
        return (session_state, session_state["history_ui"],
                "Session ended — click New Session to restart.",
                "", "", "", "")

    history = session_state["history"]
    history_ui = session_state["history_ui"]
    sp = float(session_state["sp"])
    bp = float(session_state["bp"])
    turn = session_state["turn"]

    b_bud = float(buyer_budget) if float(buyer_budget) > 0 else lp
    b_est = float(buyer_estimate) if float(buyer_estimate) > 0 else lp * 0.75
    s_res = (float(seller_reservation)
             if float(seller_reservation) > 0 else lp * 0.80)

    # ── Record user turn ──────────────────────────────────────
    u_int = 1 if is_user_buyer else 0
    history.append({
        "party": u_int,
        "message": user_message,
        "offer_price": float(user_price),
        "msg_type": "counter",
        "turn_number": turn + 1,
    })
    history_ui.append((
        f"{'🧑 You (Buyer)' if is_user_buyer else '🧑 You (Seller)'}"
        f" [${float(user_price):,.0f}]: {user_message}",
        None
    ))
    turn += 1

    if is_user_buyer:
        bp = float(user_price)
    else:
        sp = float(user_price)

    # ── Build momentum features ───────────────────────────────
    # Use the same extractor as training so the model sees features on
    # the distribution it was trained on (the previous hand-rolled copy
    # zeroed both concession features and used a different gap fallback).
    mom = extract_features(history, len(history), lp, b_bud, b_est, s_res)

    # ── Build text context ────────────────────────────────────
    inv_ctx = session_state.get("inv_context", "")
    recent = history[-3:]
    text = " [SEP] ".join(
        f"{'S' if int(r['party']) == 0 else 'B'}: {r['message']}"
        for r in recent
    )
    if inv_ctx:
        text = f"[INV: {inv_ctx[:120]}] " + text

    enc = tokenizer(
        text, max_length=MAX_LEN,
        padding="max_length", truncation=True,
        return_tensors="pt"
    )

    dev = DEVICE
    ai_pty_t = torch.tensor([ai_party_int], dtype=torch.long).to(dev)
    cat_t = torch.tensor([CAT2IDX.get(category, 0)],
                         dtype=torch.long).to(dev)
    ofn_t = torch.tensor([min(float(user_price) / lp, 3.0)],
                         dtype=torch.float).to(dev)
    tn_t = torch.tensor([min(turn / 25.0, 1.0)],
                        dtype=torch.float).to(dev)
    bp_idx = BPERSONA2IDX.get(
        user_persona if is_user_buyer else ai_persona, 1)
    sp_idx = SPERSONA2IDX.get(
        ai_persona if is_user_buyer else user_persona, 0)
    bp_t = torch.tensor([bp_idx], dtype=torch.long).to(dev)
    sp_t = torch.tensor([sp_idx], dtype=torch.long).to(dev)
    mom_t = torch.tensor([mom], dtype=torch.float).to(dev)

    with torch.no_grad():
        mt_logits, px = GLOBAL_MODEL(
            enc["input_ids"].to(dev),
            enc["attention_mask"].to(dev),
            ai_pty_t, cat_t, ofn_t, tn_t, bp_t, sp_t, mom_t
        )

    mt_idx = mt_logits.argmax(dim=1).item()
    msg_type = IDX2MSG[mt_idx]
    # The price head predicts a ratio of the list price.
    ai_price = round(float(px.item()) * lp, 2)

    # ── Clamp AI price to valid range ─────────────────────────
    if ai_party_int == 0:  # AI is seller
        ai_price = max(ai_price, s_res * 1.005)
        ai_price = min(ai_price, lp * 1.05)
        sp = ai_price
    else:  # AI is buyer
        ai_price = min(ai_price, b_bud)
        ai_price = min(ai_price, sp * 0.99)
        ai_price = max(ai_price, lp * 0.25)
        bp = ai_price

    # ── Execute inventory search if triggered ─────────────────
    inv_context_text = ""
    if msg_type == "search":
        if ai_party_int == 0:  # Seller searches for buyer
            results = search_inventory(
                category=category,
                max_price=b_bud if b_bud > 0 else lp * 1.1,
                min_price=b_est * 0.8 if b_est > 0 else 0,
                keywords=buyer_must_have,
                avoids=buyer_avoids,
                top_k=3,
            )
            inv_context_text = format_inventory_context(
                results, reveal_floor=True
            )
        else:  # Buyer searches seller inventory
            results = search_inventory(
                category=category,
                max_price=b_bud if b_bud > 0 else lp,
                keywords=buyer_must_have,
                avoids=buyer_avoids,
                top_k=3,
            )
            inv_context_text = format_inventory_context(
                results, reveal_floor=False
            )
        # NOTE(review): the source view lost its indentation; this
        # assignment is assumed to run only after a search so earlier
        # results persist across later turns - confirm.
        session_state["inv_context"] = inv_context_text

    # ── Build AI message ──────────────────────────────────────
    ai_msg = _build_message(
        msg_type, ai_price, item,
        not is_user_buyer, ai_persona,
        inv_context_text
    )
    if msg_type == "search" and inv_context_text:
        ai_msg += (f"\n\n📦 **Inventory Results:**\n"
                   f"```\n{inv_context_text}\n```")

    history.append({
        "party": ai_party_int,
        "message": ai_msg,
        "offer_price": ai_price,
        "msg_type": msg_type,
        "turn_number": turn + 1,
    })
    ai_label = (f"🤖 AI ({'Seller' if ai_party_int == 0 else 'Buyer'}) "
                f"[{ai_persona}]")
    history_ui.append((None, f"{ai_label} [${ai_price:,.0f}]: {ai_msg}"))
    turn += 1

    # ── ZOPA ──────────────────────────────────────────────────
    # Computed from the buyer's current offer versus the seller reservation.
    zopa = bp - s_res
    zopa_str = (f"✅ ZOPA: +${zopa:,.0f} (deal zone exists)"
                if zopa > 0
                else f"❌ ZOPA: ${zopa:,.0f} (no overlap yet)")

    # ── Terminal check ────────────────────────────────────────
    status = "active"
    if msg_type == "accept":
        status = "closed"
        history_ui.append(
            (None, f"✅ **DEAL CLOSED at ${ai_price:,.0f}**")
        )
    elif msg_type == "exit":
        status = "ended"
        history_ui.append((None, "❌ Negotiation ended"))

    probs = F.softmax(mt_logits, dim=1).squeeze().tolist()
    prob_str = " | ".join(
        f"{MSG_TYPES[i]}: {probs[i]:.2f}" for i in range(len(MSG_TYPES))
    )
    gap_pct = abs(sp - bp) / lp * 100
    summary = (f"Turn {turn} | Gap: {gap_pct:.1f}% | "
               f"Seller: ${sp:,.0f} | Buyer: ${bp:,.0f} | {zopa_str}")

    session_state.update({
        "turn": turn,
        "sp": sp,
        "bp": bp,
        "history": history,
        "history_ui": history_ui,
        "status": status,
    })

    return (session_state, history_ui, summary,
            msg_type, f"${ai_price:,.2f}", prob_str, inv_context_text)
1266
+
1267
+
1268
def reset_session():
    """Return blank values for every arena output so a new session starts clean.

    The tuple holds, in order: an empty session-state dict, an empty chat
    history list, the summary text ``"Session reset."`` and four empty
    strings for the remaining text outputs wired to the reset button.
    """
    fresh_state = {}
    fresh_chat = []
    cleared_fields = ("",) * 4
    return (fresh_state, fresh_chat, "Session reset.", *cleared_fields)
1270
+
1271
+ # ================================================================
1272
+ # STRATEGY GUIDES
1273
+ # ================================================================
1274
# Markdown shown in the "Playbooks" tab for users playing the buyer.
# The text is user-facing, so it must contain real Unicode characters
# (em dashes and emoji), not their mis-decoded byte sequences.
BUYER_GUIDE = """### 📋 Buyer Playbook
**Bounds to set before starting:**
- **Budget** — your true ceiling. Encoded as soft penalty, not hard wall.
- **Estimate** — fair value anchor. Sets your opening offer range.
- **Must-have features** — filters inventory search. e.g. *bluetooth, low miles*
- **Hard avoids** — instant deal-breakers. e.g. *salvage title, high mileage*

**Tactics the model trains on:**
- 🔴 Aggressive open at 55-65% of ask
- 🚪 Walk away at turn 3-4, return with prior offer
- 🔍 Trigger search when gap > 12%: *"Do you have anything else?"*
- ⏰ Deadline pressure after patience threshold
- 🍪 Nibble for extras when gap < 8%
- 🤝 Strategic persona: cite comps, build rapport"""
1288
+
1289
# Markdown shown in the "Playbooks" tab for users playing the seller.
# The text is user-facing, so it must contain real Unicode characters
# (em dashes and emoji), not their mis-decoded byte sequences.
SELLER_GUIDE = """### 📋 Seller Playbook
**Bounds to set before starting:**
- **Reservation price** — private floor. Model NEVER accepts below this.
- **Urgency** — high urgency raises concession rate and search frequency.
- **Inventory** — pre-loaded. Searched when buyer asks for alternatives.

**Tactics the model trains on:**
- ⚓ Open 15-20% above target
- 👥 Social proof: *"Two other buyers this weekend"*
- 🔍 Proactively search inventory when buyer signals dissatisfaction
- ⏰ Urgency close: *"Need to close by Friday"*
- 📞 Return after walkaway with small concession
- 📉 Shrinking concessions signal approaching floor"""
1302
+
1303
+ # ================================================================
1304
+ # UI
1305
+ # ================================================================
1306
# Gradio UI: four tabs (training, negotiation arena, playbooks, inventory)
# plus the event wiring that connects widgets to the handler functions.
# NOTE(review): this block was rebuilt from a diff view that had lost all
# indentation; the nesting of components inside rows/columns is a
# best-effort reconstruction — confirm against the rendered layout.
with gr.Blocks(title="ANP v5 | Bounded Negotiation",
               theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ANP v5 — Bounded Negotiation Engine\n"
        "Buyer bounds · Seller reservation · Inventory tool use · "
        "ZOPA tracking · Persona conditioning"
    )

    # ── Training Tab ──────────────────────────────────────────
    with gr.Tab("🏋️ Training"):
        with gr.Row():
            n_sessions = gr.Number(value=20000, label="Sessions")
            epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
            batch_size = gr.Slider(64, 1024, value=512, step=64,
                                   label="Batch Size")
            lr = gr.Number(value=3e-4, label="LR")
        tr_btn = gr.Button("🚀 Train", variant="primary")
        status_box = gr.Textbox(label="Status", interactive=False,
                                value="🔵 IDLE")
        with gr.Row():
            log_box = gr.Textbox(label="Logs", lines=14, interactive=False)
            plt_out = gr.Plot(label="Loss Curve")
        # Hidden sink for the fourth value returned by run_training.
        train_ready = gr.Textbox(visible=False)

    # ── Arena Tab ─────────────────────────────────────────────
    with gr.Tab("💬 Negotiation Arena"):
        with gr.Row():

            # Left panel — setup & analysis
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Session Setup")
                arena_cat = gr.Dropdown(
                    CATEGORIES, value="used_car", label="Category"
                )
                arena_item = gr.Textbox(
                    value="2019 Honda Civic", label="Item Name"
                )
                arena_lp = gr.Number(value=18500, label="List Price ($)")

                with gr.Row():
                    arena_user_pty = gr.Radio(
                        ["Buyer", "Seller"], value="Buyer", label="You are"
                    )
                with gr.Row():
                    arena_user_persona = gr.Dropdown(
                        BUYER_PERSONAS, value="strategic",
                        label="Your Persona"
                    )
                    arena_ai_persona = gr.Dropdown(
                        SELLER_PERSONAS, value="firm",
                        label="AI Persona"
                    )

                gr.Markdown("---\n### 🧑 Buyer Bounds")
                buyer_budget = gr.Number(value=17000,
                                         label="Max Budget ($)")
                buyer_estimate = gr.Number(value=15500,
                                           label="Fair Value Estimate ($)")
                buyer_avoids = gr.Textbox(
                    value="salvage,flood",
                    label="Hard Avoids (comma list)"
                )
                buyer_must_have = gr.Textbox(
                    value="bluetooth",
                    label="Must-Have Features (comma list)"
                )

                gr.Markdown("---\n### 🤖 Seller Bounds")
                seller_reservation = gr.Number(
                    value=15000, label="Seller Floor / Reservation ($)"
                )
                seller_urgency = gr.Dropdown(
                    ["low", "medium", "high"], value="medium",
                    label="Seller Urgency"
                )

                reset_btn = gr.Button("🔄 New Session", variant="secondary")

                gr.Markdown("---\n### 📊 Turn Analysis")
                arena_summary = gr.Textbox(
                    label="Gap / ZOPA", interactive=False
                )
                arena_action = gr.Textbox(
                    label="AI Action", interactive=False
                )
                arena_price = gr.Textbox(
                    label="AI Price", interactive=False
                )
                arena_probs = gr.Textbox(
                    label="Action Probabilities", interactive=False
                )
                inv_display = gr.Textbox(
                    label="🔍 Last Inventory Search",
                    lines=5, interactive=False
                )

            # Right panel — chat
            with gr.Column(scale=2):
                gr.Markdown("### 🗣️ Negotiation")
                chatbot = gr.Chatbot(height=520, label="Conversation")
                with gr.Row():
                    arena_offer = gr.Number(value=16000,
                                            label="Your Offer ($)")
                    arena_msg = gr.Textbox(
                        placeholder="Type your message...",
                        label="Your Message", scale=3
                    )
                send_btn = gr.Button("Send →", variant="primary")

    # ── Strategy Guides Tab ───────────────────────────────────
    with gr.Tab("📚 Playbooks"):
        with gr.Row():
            gr.Markdown(BUYER_GUIDE)
            gr.Markdown(SELLER_GUIDE)

    # ── Inventory Browser Tab ─────────────────────────────────
    with gr.Tab("📦 Inventory"):
        gr.Markdown(
            "### Current Inventory Database\n"
            "Plain text rows — term-frequency search, no vectors at rest."
        )
        # One line per inventory item; the listing is built once at UI
        # construction time from the module-level INVENTORY list.
        inv_text = "\n".join(
            f"[{it['id']}] {it['category']} | {it['name']} | "
            f"{it['condition']} | Ask: ${it['ask_price']:,} | "
            f"Features: {it['features']}"
            for it in INVENTORY
        )
        gr.Textbox(
            value=inv_text, lines=30, interactive=False,
            label="Inventory (floor hidden from buyer-facing searches)"
        )

    # ── State ─────────────────────────────────────────────────
    # Per-browser-session dict passed into and returned from each turn.
    session_state = gr.State({})

    def update_personas(party):
        """Swap the persona dropdown choices when the user changes side.

        Returns two dropdown updates in the order (user persona, AI persona):
        the user gets the persona list for the chosen side and the AI gets
        the list for the opposite side, each reset to its default value.
        """
        if party == "Buyer":
            return (
                gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
                gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
            )
        return (
            gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
            gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
        )

    arena_user_pty.change(
        update_personas,
        inputs=[arena_user_pty],
        outputs=[arena_user_persona, arena_ai_persona]
    )

    tr_btn.click(
        run_training,
        inputs=[n_sessions, epochs, batch_size, lr],
        outputs=[status_box, log_box, plt_out, train_ready]
    )

    send_btn.click(
        run_inference_turn,
        inputs=[
            session_state,
            arena_cat, arena_item, arena_lp,
            arena_offer, arena_msg,
            arena_user_pty, arena_user_persona, arena_ai_persona,
            buyer_budget, buyer_estimate,
            buyer_avoids, buyer_must_have,
            seller_reservation, seller_urgency,
        ],
        outputs=[
            session_state, chatbot, arena_summary,
            arena_action, arena_price, arena_probs, inv_display,
        ]
    )

    # Reset clears the same seven outputs that a turn writes to.
    reset_btn.click(
        reset_session,
        outputs=[
            session_state, chatbot, arena_summary,
            arena_action, arena_price, arena_probs, inv_display,
        ]
    )
1488
+
1489
# Start the server only when this file is executed as a script, so that
# importing the module (e.g. by a reloader or a test) does not launch it.
if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    # share=True is intentionally not passed: Hugging Face Spaces does not
    # support share links, and on a local machine it would expose the app
    # (including the training endpoint) through a public tunnel.
    demo.launch(server_name="0.0.0.0", server_port=7860)