Spaces:
Sleeping
Sleeping
# ================================================================
# ANP v5 | Bounded Multi-Agent Negotiation + Inventory Tool Use
# Buyer bounds · Seller inventory context · Search action head
# ZOPA tracking · Reservation prices · Ranked inventory matching
# ================================================================
| import os, time, math, random, uuid, gc | |
| from typing import List, Dict, Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from transformers import BertTokenizerFast | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
random.seed(42)
torch.manual_seed(42)

# ── Config ────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Message types; list position is the class index used by MSG2IDX/IDX2MSG.
MSG_TYPES = ["offer", "counter", "accept", "reject", "exit", "stall", "search"]
MSG2IDX = {name: idx for idx, name in enumerate(MSG_TYPES)}
IDX2MSG = dict(enumerate(MSG_TYPES))

CATEGORIES = ["used_car", "domain_name", "freelance_design", "saas_license",
              "electronics", "bulk_groceries", "consulting"]
CAT2IDX = {name: idx for idx, name in enumerate(CATEGORIES)}

BUYER_PERSONAS = ["aggressive", "patient", "skeptical", "impulsive", "strategic"]
SELLER_PERSONAS = ["firm", "motivated", "anchoring", "collaborative", "desperate"]
BPERSONA2IDX = {name: idx for idx, name in enumerate(BUYER_PERSONAS)}
SPERSONA2IDX = {name: idx for idx, name in enumerate(SELLER_PERSONAS)}

# Sequence length (tokenizer max_length) and transformer dimensions.
MAX_LEN = 96
D_MODEL = 384
N_HEADS = 6
N_LAYERS = 6
FFN_DIM = 1024

# Permit TF32 math and cuDNN autotuning.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print(f"Device: {DEVICE}")

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# NOTE(review): presumably assigned by run_training (it declares this global).
GLOBAL_MODEL = None
| # ================================================================ | |
| # INVENTORY DATABASE | |
| # ================================================================ | |
| def _make_inventory() -> List[Dict]: | |
| inv = [] | |
| templates = { | |
| "used_car": [ | |
| ("2018 Toyota Camry", "Good", 14500, 12800, | |
| "sunroof,bluetooth,low miles"), | |
| ("2019 Honda Civic", "Excellent", 18500, 16200, | |
| "one owner,new tires,clean title"), | |
| ("2020 Ford F-150", "Good", 28000, 24500, | |
| "tow package,crew cab,4WD"), | |
| ("2016 BMW 3 Series", "Fair", 16000, 13500, | |
| "sport package,leather,sunroof"), | |
| ("2021 Tesla Model 3", "Excellent", 38000, 35000, | |
| "autopilot,long range,premium audio"), | |
| ("2017 Chevy Silverado", "Good", 22000, 19000, | |
| "4WD,tow hitch,extended cab"), | |
| ("2015 Honda Accord", "Fair", 11000, 9200, | |
| "2 owners,new brakes,cloth seats"), | |
| ("2022 Toyota RAV4", "Excellent", 32000, 29500, | |
| "hybrid,AWD,apple carplay"), | |
| ], | |
| "electronics": [ | |
| ("MacBook Pro 14 M2", "Excellent", 1800, 1600, | |
| "16GB RAM,512GB SSD,AppleCare"), | |
| ("iPhone 14 Pro", "Good", 900, 780, | |
| "256GB,space black,minor scratches"), | |
| ("Sony 65in 4K TV", "Excellent", 750, 620, | |
| "OLED,smart tv,2 years old"), | |
| ("iPad Air Gen5", "Good", 550, 470, | |
| "wifi+cellular,pencil included"), | |
| ("Gaming PC RTX4070", "Excellent", 1400, 1200, | |
| "32GB RAM,1TB NVMe,water cooled"), | |
| ("DJI Mavic 3", "Good", 900, 780, | |
| "4K camera,3 batteries,case"), | |
| ], | |
| "domain_name": [ | |
| ("QuickLoan.io", "Premium", 12000, 9500, | |
| "fintech,4 years aged,high DA"), | |
| ("FreshMeals.com", "Good", 4500, 3800, | |
| "food delivery niche,aged 6yr"), | |
| ("TechPulse.net", "Good", 2200, 1800, | |
| "tech blog ready,clean history"), | |
| ("GreenHome.co", "Excellent", 5500, 4600, | |
| "eco niche,brandable,short"), | |
| ("RapidShip.io", "Premium", 8000, 6800, | |
| "logistics niche,exact match"), | |
| ], | |
| "freelance_design": [ | |
| ("Logo + Brand Kit", "Standard", 800, 650, | |
| "5 concepts,unlimited revisions,source files"), | |
| ("Website Redesign", "Premium", 3500, 2800, | |
| "5 pages,mobile,figma handoff"), | |
| ("UI/UX App Design", "Premium", 5000, 4200, | |
| "full wireframes,prototype,design system"), | |
| ("Social Media Pack", "Standard", 600, 480, | |
| "30 templates,brand colors,canva ready"), | |
| ("Pitch Deck Design", "Standard", 1200, 950, | |
| "20 slides,animations,2 revisions"), | |
| ], | |
| "saas_license": [ | |
| ("CRM Pro Annual", "Standard", 2400, 1900, | |
| "unlimited users,API access,support"), | |
| ("Analytics Suite", "Premium", 4800, 3900, | |
| "real-time,custom dashboards,export"), | |
| ("Project Mgmt Tool", "Standard", 1200, 980, | |
| "50 users,gantt,integrations"), | |
| ("Email Marketing Pro", "Standard", 960, 780, | |
| "100k contacts,automation,A/B"), | |
| ], | |
| "bulk_groceries": [ | |
| ("Organic Coffee 50lb", "Fresh", 420, 350, | |
| "single origin,roasted weekly,wholesale"), | |
| ("Olive Oil 5 Gal", "Premium", 280, 230, | |
| "extra virgin,cold press,Italian"), | |
| ("Almond Flour 25lb", "Fresh", 180, 145, | |
| "blanched,gluten free,bulk"), | |
| ("Protein Powder 20lb", "Good", 260, 210, | |
| "whey isolate,unflavored,NSF cert"), | |
| ], | |
| "consulting": [ | |
| ("SEO Audit + 90 Day Plan", "Standard", 1500, 1200, | |
| "technical+content,keyword research,monthly report"), | |
| ("Financial Model Build", "Premium", 3500, 2900, | |
| "3 statement,DCF,scenario analysis"), | |
| ("HR Policy Package", "Standard", 1800, 1450, | |
| "employee handbook,policies,compliance"), | |
| ("Marketing Strategy Q", "Premium", 4200, 3500, | |
| "market research,ICP,channel plan"), | |
| ], | |
| } | |
| for cat, items in templates.items(): | |
| for (name, cond, ask, res, feats) in items: | |
| inv.append({ | |
| "id": str(uuid.uuid4().hex[:8]), | |
| "category": cat, | |
| "name": name, | |
| "condition": cond, | |
| "ask_price": ask, | |
| "reservation_price": res, | |
| "features": feats, | |
| "notes": "", | |
| }) | |
| return inv | |
| INVENTORY: List[Dict] = _make_inventory() | |
def search_inventory(
    category: str,
    max_price: float,
    min_price: float = 0,
    keywords: str = "",
    top_k: int = 4,
    avoids: str = "",
    inventory: Optional[List[Dict]] = None,
) -> List[Dict]:
    """Rank inventory items in ``category`` against a buyer's price/keyword brief.

    Filtering:
      * only items whose ``category`` equals ``category``;
      * ``ask_price`` must be >= ``min_price`` and <= ``1.15 * max_price``
        (a 15% stretch over the stated maximum is tolerated);
      * items whose name/features/notes contain any comma-separated
        ``avoids`` term (case-insensitive substring) are dropped.

    Scoring: ``2 * (number of keyword terms found) - price_dist``, where
    ``price_dist`` is the ask price's relative distance from a target price:
    the midpoint of ``[min_price, max_price]`` when ``min_price > 0``,
    otherwise ``0.8 * max_price``.

    Args:
        category: category key to search.
        max_price: buyer's stated maximum.
        min_price: lower price bound (0 disables it).
        keywords: comma-separated desired terms (``None`` treated as "").
        top_k: maximum number of results; values <= 0 yield ``[]``.
        avoids: comma-separated excluded terms (``None`` treated as "").
        inventory: items to search; defaults to the module-level ``INVENTORY``.

    Returns:
        Shallow copies of the matching items, each with an added ``_score``
        key, sorted best-first and truncated to ``top_k``.
    """
    pool = INVENTORY if inventory is None else inventory
    kws = [k.strip().lower() for k in (keywords or "").split(",") if k.strip()]
    avd = [a.strip().lower() for a in (avoids or "").split(",") if a.strip()]
    # Both of these are the same for every item, so compute them once.
    mid = (max_price + min_price) / 2 if min_price > 0 else max_price * 0.8
    ceiling = max_price * 1.15
    results = []
    for item in pool:
        if item["category"] != category:
            continue
        if item["ask_price"] > ceiling or item["ask_price"] < min_price:
            continue
        combined = f"{item['name']} {item['features']} {item['notes']}".lower()
        if any(av in combined for av in avd):
            continue
        kw_score = sum(1 for kw in kws if kw in combined)
        price_dist = abs(item["ask_price"] - mid) / max(mid, 1)
        results.append({**item, "_score": kw_score * 2 - price_dist})
    results.sort(key=lambda x: x["_score"], reverse=True)
    # int() accepts float counts (e.g. from UI sliders); clamping at 0 stops a
    # negative top_k from meaning "all but the last |k|" in slice semantics.
    return results[:max(int(top_k), 0)]
def format_inventory_context(
    items: List[Dict], reveal_floor: bool = False
) -> str:
    """Render inventory items as newline-separated, pipe-delimited lines.

    Each line reads ``[id] name | condition | Ask: $X | Features: ...``,
    with ``| Floor: $Y`` (the reservation price) appended only when
    ``reveal_floor`` is true. Prices use thousands separators. An empty
    ``items`` returns a fixed "no matches" sentence instead.
    """
    if not items:
        return "No matching inventory found."

    def render(entry: Dict) -> str:
        fields = [
            f"[{entry['id']}] {entry['name']}",
            f"{entry['condition']}",
            f"Ask: ${entry['ask_price']:,}",
            f"Features: {entry['features']}",
        ]
        if reveal_floor:
            fields.append(f"Floor: ${entry['reservation_price']:,}")
        return " | ".join(fields)

    return "\n".join(render(entry) for entry in items)
| # ================================================================ | |
| # TEMPLATES | |
| # ================================================================ | |
# Message phrasings keyed by "<role>_<situation>". Each value is a list of
# interchangeable variants; ``_t`` picks one at random and formats it with
# ``item`` (item label) and ``p`` (price, rendered via ``{p:,.0f}``).
# Dashes are real em dashes (U+2014); the strings reach generated
# training text and the UI verbatim.
TEMPLATES = {
    "seller_open_firm": [
        "I've had this {item} listed and I'm firm at ${p:,.0f}. "
        "It's priced fairly for the condition.",
        "The market supports ${p:,.0f} for a {item} like this. "
        "I've done my research.",
        "Asking ${p:,.0f} for the {item}. I'm not in a rush \u2014 "
        "prefer not to negotiate far from that.",
    ],
    "seller_open_motivated": [
        "I'm listing the {item} at ${p:,.0f} but open to reasonable offers. "
        "I'd like to move this quickly.",
        "Got this {item} up for ${p:,.0f}. "
        "Motivated to sell \u2014 make me an offer.",
        "Selling the {item} at ${p:,.0f}. "
        "I have flexibility if you're serious about buying today.",
    ],
    "seller_counter_hold": [
        "I appreciate the offer but I can't go below ${p:,.0f}. "
        "That's really my floor.",
        "I hear you, but ${p:,.0f} is already a stretch. "
        "I have other interested buyers closer to asking.",
        "That doesn't quite work. I could come to ${p:,.0f} "
        "but that's genuinely as low as I go.",
    ],
    "seller_counter_concede": [
        "Alright, I can meet you a bit closer \u2014 how does ${p:,.0f} sound?",
        "I've thought about it and I can work with ${p:,.0f} "
        "if we can close today.",
        "Let me split the difference with you. ${p:,.0f} \u2014 fair?",
    ],
    "seller_stall": [
        "Let me think on that overnight. "
        "I want to make sure I'm not leaving too much on the table.",
        "I've got another showing tomorrow. "
        "Give me until then to decide if your number works.",
        "I need to check with my partner before I commit to that price.",
    ],
    "seller_reject": [
        "I can't do that price \u2014 it doesn't cover what I have into this.",
        "That's too far from asking. I'd rather hold onto it.",
        "I appreciate you trying but that number doesn't work for me at all.",
    ],
    "seller_return_after_walkaway": [
        "Hey, I've been thinking. The other buyer fell through \u2014 "
        "would you still do ${p:,.0f}?",
        "Circling back \u2014 other deal didn't pan out. "
        "If ${p:,.0f} is still on the table I'd like to make it work.",
        "The showing yesterday didn't go anywhere. "
        "I'm willing to revisit your ${p:,.0f}.",
    ],
    "seller_urgency": [
        "Someone else is coming to look this weekend. "
        "If you want it at ${p:,.0f} I need to know by tomorrow.",
        "Just so you know I've got two other people interested. "
        "First right of refusal at ${p:,.0f}.",
        "My situation has changed and I need to close this week. "
        "${p:,.0f} only if we finalize today.",
    ],
    "seller_accept": [
        "You know what, ${p:,.0f} works. Let's do it.",
        "Deal. ${p:,.0f} and it's yours.",
        "Alright, I'll take ${p:,.0f}. When can you pick it up?",
    ],
    "seller_exit": [
        "I don't think we're going to get there on price. "
        "Good luck with your search.",
        "We're too far apart. I'm going to wait for a better offer.",
        "I appreciate the interest but this isn't going to work "
        "at your number.",
    ],
    "seller_search": [
        "Let me check if I have something that better fits "
        "what you're describing.",
        "Hold on \u2014 I think I may have another option in my inventory "
        "that suits your needs.",
        "I want to make sure I'm showing you the best match. "
        "Let me pull some alternatives.",
    ],
    "buyer_open_aggressive": [
        "I'll offer ${p:,.0f} and that's already above what I was "
        "planning to spend.",
        "I can do ${p:,.0f} cash today. "
        "I know that's low but I need to stay in my budget.",
        "First and best offer: ${p:,.0f}. "
        "I've seen similar {item}s go for less.",
    ],
    "buyer_open_strategic": [
        "I've done some research on {item} values in this market. "
        "Based on comps I think ${p:,.0f} is fair.",
        "I'm genuinely interested. I'd like to start at ${p:,.0f} \u2014 "
        "I think there's a deal here.",
        "Serious buyer, ready to close fast. "
        "With that in mind, ${p:,.0f}.",
    ],
    "buyer_counter_nibble": [
        "Getting closer. Can you do ${p:,.0f}? "
        "That's where I need to be to feel good about the deal.",
        "I'd say yes at ${p:,.0f}. "
        "Throw in the extras and I'll pull the trigger right now.",
        "If you can get to ${p:,.0f} I won't waste any more of "
        "your time \u2014 deal done.",
    ],
    "buyer_counter_hold": [
        "I've thought about it and I'm still at ${p:,.0f}. "
        "That's genuinely what this is worth to me.",
        "My budget hasn't changed. ${p:,.0f} is the number.",
        "I hear you on the other buyers but ${p:,.0f} is my ceiling.",
    ],
    "buyer_stall": [
        "I need to sleep on it. "
        "I'm also looking at a couple other options this week.",
        "Let me talk to my partner tonight and get back to you tomorrow.",
        "I'm not going to rush into this. Give me a day or two.",
    ],
    "buyer_walkaway": [
        "I don't think we're going to get there. "
        "Thanks for your time \u2014 good luck with the sale.",
        "I'm going to pass. The price just doesn't work for what I need.",
        "Going to look at other options. "
        "If your price changes, feel free to reach out.",
    ],
    "buyer_return_after_walkaway": [
        "Hey, been thinking about the {item} since we talked. "
        "Is ${p:,.0f} still the best you can do?",
        "Still have the {item} available? "
        "I might stretch to ${p:,.0f} if we can close quickly.",
        "Came back because I couldn't find anything comparable. "
        "Would you take ${p:,.0f}?",
    ],
    "buyer_accept": [
        "Alright, you've got a deal at ${p:,.0f}.",
        "Fine, ${p:,.0f}. Let's stop going back and forth \u2014 I'll take it.",
        "Done. ${p:,.0f}. When can I come get it?",
    ],
    "buyer_reject": [
        "That's still too high. I can't justify that price.",
        "No, that doesn't work. "
        "I'd need to see a significant move to reconsider.",
        "I'm out at that number. "
        "Not what the market is bearing right now.",
    ],
    "buyer_deadline": [
        "I need to make a decision by end of day \u2014 "
        "can you give me your absolute best price?",
        "My budget approval expires Friday. "
        "If we agree on ${p:,.0f} right now I can move immediately.",
        "I have to make a call today. "
        "Meet me at ${p:,.0f} and we close this out.",
    ],
    "buyer_search": [
        "Do you have anything else in this category that might "
        "work better for my needs?",
        "I'm not sure this is the right fit. "
        "Do you have other options I should look at?",
        "Before I decide, do you have alternatives \u2014 "
        "maybe different condition or price point?",
    ],
}
def _t(key: str, item: str = "", p: float = 0,
       avoid: str = "", must: str = "") -> str:
    """Return one randomly chosen ``TEMPLATES[key]`` variant, filled in.

    The variant is picked with ``random.choice``, so it follows the module's
    ``random`` seed. ``item`` and ``p`` fill the ``{item}``/``{p}``
    placeholders; ``avoid``/``must`` are passed through to ``str.format`` as
    well. Raises ``KeyError`` for an unknown ``key``.
    """
    variants = TEMPLATES[key]
    chosen = random.choice(variants)
    return chosen.format(item=item, p=p, avoid=avoid, must=must)
| # ================================================================ | |
| # STRATEGY PROFILES | |
| # ================================================================ | |
# Per-persona behaviour knobs read by generate_sessions.
#   open_discount   - (lo, hi) range for the opening offer as a fraction of list
#   concession_rate - base concession per move, as a fraction of list price
#   walkaway_prob / return_prob - chance to walk away, and to come back after
#   patience        - round index after which deadline pressure may kick in
#   search_prob     - chance of asking for alternatives
BUYER_STRATEGY = {
    "aggressive": dict(open_discount=(0.55, 0.68), concession_rate=0.015,
                       walkaway_prob=0.35, return_prob=0.50,
                       patience=3, search_prob=0.10),
    "patient": dict(open_discount=(0.72, 0.82), concession_rate=0.025,
                    walkaway_prob=0.15, return_prob=0.70,
                    patience=8, search_prob=0.20),
    "skeptical": dict(open_discount=(0.65, 0.75), concession_rate=0.018,
                      walkaway_prob=0.28, return_prob=0.45,
                      patience=5, search_prob=0.30),
    "impulsive": dict(open_discount=(0.78, 0.88), concession_rate=0.040,
                      walkaway_prob=0.10, return_prob=0.30,
                      patience=2, search_prob=0.05),
    "strategic": dict(open_discount=(0.62, 0.72), concession_rate=0.022,
                      walkaway_prob=0.30, return_prob=0.65,
                      patience=7, search_prob=0.25),
}
# Seller persona knobs. ``urgency_prob`` drives urgency "stall" messages.
# NOTE(review): ``min_discount`` is not read by generate_sessions, which draws
# its own reservation price instead — confirm whether it is used elsewhere.
SELLER_STRATEGY = {
    "firm": dict(min_discount=0.93, concession_rate=0.008,
                 urgency_prob=0.15, return_prob=0.30, search_prob=0.15),
    "motivated": dict(min_discount=0.82, concession_rate=0.030,
                      urgency_prob=0.40, return_prob=0.60, search_prob=0.35),
    "anchoring": dict(min_discount=0.90, concession_rate=0.010,
                      urgency_prob=0.25, return_prob=0.40, search_prob=0.20),
    "collaborative": dict(min_discount=0.86, concession_rate=0.022,
                          urgency_prob=0.20, return_prob=0.55, search_prob=0.40),
    "desperate": dict(min_discount=0.75, concession_rate=0.045,
                      urgency_prob=0.60, return_prob=0.75, search_prob=0.30),
}
| # ================================================================ | |
| # DATA GENERATOR | |
| # ================================================================ | |
def generate_sessions(n_sessions: int) -> List[Dict]:
    """Simulate ``n_sessions`` synthetic buyer/seller negotiations.

    Each session draws a category, a list price ``lp``, one buyer and one
    seller persona, and hidden bounds: the buyer's budget and private
    estimate and the seller's reservation (floor) price. The seller
    (party 0) opens at ``lp``, the buyer (party 1) counters, and the two
    then alternate persona-driven moves (counter, stall, search, reject,
    exit, accept) until a deal, a walk-away, or the turn budget runs out.
    Prices produced by search and return-after-walkaway moves are clamped
    to the seller's floor / buyer's budget, like the concession moves.

    Args:
        n_sessions: number of sessions to simulate (coerced with ``int``).

    Returns:
        A flat list of per-turn row dicts, grouped by session and ordered
        by ``turn_number`` within a session. Keys: session_id, turn_number,
        party, category, item, list_price, offer_price, msg_type, message,
        buyer_persona, seller_persona, buyer_budget, buyer_estimate,
        seller_reservation.
    """
    all_rows = []
    for _ in range(int(n_sessions)):
        cat = random.choice(CATEGORIES)
        item = cat.replace("_", " ").title()
        lp = round(random.uniform(500, 25000), -1)
        sid = f"SYN-{uuid.uuid4().hex[:6].upper()}"
        b_persona = random.choice(BUYER_PERSONAS)
        s_persona = random.choice(SELLER_PERSONAS)
        bs = BUYER_STRATEGY[b_persona]
        ss = SELLER_STRATEGY[s_persona]
        turn = 0
        rows = []
        walked = False
        # Hidden per-session bounds, as fractions of the list price.
        b_budget = lp * random.uniform(0.85, 1.05)
        b_estimate = lp * random.uniform(0.65, 0.80)
        s_reserve = lp * random.uniform(0.72, 0.88)

        def add(party, price, mtype, msg):
            # Append one turn to the current session's rows.
            nonlocal turn
            turn += 1
            rows.append({
                "session_id": sid,
                "turn_number": turn,
                "party": party,
                "category": cat,
                "item": item,
                "list_price": lp,
                "offer_price": round(price, 2),
                "msg_type": mtype,
                "message": msg,
                "buyer_persona": b_persona,
                "seller_persona": s_persona,
                "buyer_budget": b_budget,
                "buyer_estimate": b_estimate,
                "seller_reservation": s_reserve,
            })

        def add_accept(price):
            # Pick the accepting party first, then a template written from
            # that party's side, so the speaker label derived from ``party``
            # matches the wording (a buyer never says "it's yours").
            party = random.choice([0, 1])
            key = "seller_accept" if party == 0 else "buyer_accept"
            add(party, price, "accept", _t(key, p=price))

        sp = lp
        bp = round(lp * random.uniform(*bs["open_discount"]), -1)
        s_tmpl = ("seller_open_motivated"
                  if s_persona in ["motivated", "desperate"]
                  else "seller_open_firm")
        b_tmpl = ("buyer_open_aggressive"
                  if b_persona == "aggressive"
                  else "buyer_open_strategic")
        add(0, sp, "offer", _t(s_tmpl, item=item, p=sp))
        add(1, bp, "counter", _t(b_tmpl, item=item, p=bp))
        max_turns = random.randint(8, 24)
        prev_sp = sp
        prev_bp = bp
        stall_streak = 0
        for rnd in range(max_turns):
            gap = sp - bp
            gap_pct = gap / lp if lp > 0 else 0
            # Natural close: prices within 3% of list.
            if gap_pct < 0.03:
                fp = round((sp + bp) / 2, -1)
                if random.random() < 0.75:
                    add_accept(fp)
                    break
            # ── Seller turn ───────────────────────────────────
            if random.random() < ss["search_prob"] and rnd > 1:
                add(0, sp, "search", _t("seller_search"))
                # An "alternative" must not be offered below the floor.
                match_p = round(
                    max(sp * random.uniform(0.88, 0.98), s_reserve), -1)
                add(0, match_p, "counter",
                    f"I found something that might work better \u2014 "
                    f"similar {item} at ${match_p:,.0f} with better "
                    f"specs for your needs.")
                sp = match_p
                stall_streak = 0
            elif random.random() < ss["urgency_prob"] and rnd > 1:
                add(0, sp, "stall", _t("seller_urgency", item=item, p=sp))
                stall_streak += 1
            elif gap_pct > 0.30:
                add(0, sp, "reject", _t("seller_reject"))
            elif prev_sp == sp and stall_streak < 2:
                add(0, sp, "stall", _t("seller_stall"))
                stall_streak += 1
            else:
                concede_s = (ss["concession_rate"] * lp
                             * random.uniform(0.5, 1.5))
                sp = max(max(bp + gap * 0.15, sp - concede_s), s_reserve)
                sp = round(sp, -1)
                tmpl = ("seller_counter_concede"
                        if concede_s > lp * 0.02
                        else "seller_counter_hold")
                add(0, sp, "counter", _t(tmpl, p=sp))
                stall_streak = 0
            prev_sp = sp
            # ``gap`` is refreshed here (after the seller's move) and is not
            # updated by the buyer's move below; ``gap_pct`` is still the
            # value from the start of the round.
            gap = sp - bp
            # ── Buyer turn ────────────────────────────────────
            concede_b = (bs["concession_rate"] * lp
                         * random.uniform(0.5, 1.5))
            if (random.random() < bs["search_prob"]
                    and gap_pct > 0.12 and rnd > 1):
                add(1, bp, "search", _t("buyer_search"))
                # Stay within budget, like the deadline/concede paths.
                new_bp = round(
                    min(bp * random.uniform(1.01, 1.06), b_budget), -1)
                add(1, new_bp, "counter",
                    f"I looked at your alternatives \u2014 I could do "
                    f"${new_bp:,.0f} for the right {item} with the "
                    f"features I need.")
                bp = new_bp
            elif (not walked
                  and random.random() < bs["walkaway_prob"]
                  and rnd > 2):
                walked = True
                add(1, bp, "exit", _t("buyer_walkaway"))
                if random.random() < bs["return_prob"]:
                    rp = round(min(bp * 1.04, b_budget), -1)
                    add(1, rp, "counter",
                        _t("buyer_return_after_walkaway",
                           item=item, p=rp))
                    bp = rp
                else:
                    break
            elif rnd > bs["patience"] and random.random() < 0.30:
                bp = min(sp - gap * 0.1, bp + concede_b)
                bp = min(bp, b_budget)
                bp = round(bp, -1)
                add(1, bp, "counter", _t("buyer_deadline", p=bp))
            elif gap_pct < 0.08 and random.random() < 0.40:
                add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))
            elif random.random() < 0.15:
                add(1, bp, "stall", _t("buyer_stall"))
            elif prev_bp == bp and random.random() < 0.35:
                add(1, bp, "counter", _t("buyer_counter_hold", p=bp))
            else:
                bp = min(bp + concede_b, b_budget)
                bp = min(sp - gap * 0.15, bp)
                bp = round(bp, -1)
                add(1, bp, "counter", _t("buyer_counter_nibble", p=bp))
            prev_bp = bp
            if gap / lp > 0.45:
                add(1, bp, "exit", _t("buyer_reject"))
                if random.random() < ss["return_prob"]:
                    # The seller comes back lower, but never under its floor.
                    new_sp = round(max(sp * 0.94, s_reserve), -1)
                    add(0, new_sp, "counter",
                        _t("seller_return_after_walkaway", p=new_sp))
                    sp = new_sp
                else:
                    break
        else:
            # Turn budget exhausted without a break: settle if the gap is
            # under 8% of list, otherwise the buyer walks.
            if (sp - bp) / lp < 0.08:
                fp = round((sp + bp) / 2, -1)
                add_accept(fp)
            else:
                add(1, bp, "exit", _t("buyer_walkaway"))
        all_rows.extend(rows)
    return all_rows
| # ================================================================ | |
# FEATURE EXTRACTION — all list guards in place
| # ================================================================ | |
def extract_features(turns, idx, lp,
                     b_budget=0, b_estimate=0, s_reserve=0):
    """Summarise negotiation momentum over ``turns[:idx]`` as 10 floats.

    ``turns`` are row dicts with at least ``party`` (0 = seller, 1 = buyer),
    ``offer_price`` and ``msg_type``; every price term is normalised by the
    list price ``lp``. ``b_budget`` / ``b_estimate`` / ``s_reserve`` are the
    hidden bounds; when ``b_budget`` or ``s_reserve`` is 0 (or the relevant
    party has not offered yet) the matching feature falls back to 0.0 / 0.5.

    Returns, in order:
        0. seller minus buyer price velocity (first -> last offer, / lp)
        1. latest seller-buyer gap / lp, clamped to [0, 2]
           (1.0 if either side has no offer yet)
        2. total seller concessions / lp, capped at 2
        3. total buyer concessions / lp, capped at 2
        4. fraction of history turns that are stalls
        5. fraction of history turns that are searches
        6. buyer's last offer on the estimate->budget scale, clamped to [-2, 2]
        7. seller's last offer on the reserve->list scale, clamped to [-1.5, 1.5]
        8. idx / 25, capped at 1
        9. reserved, always 0.0

    An empty history, or a non-positive ``lp`` (nothing to normalise by),
    yields ten zeros instead of raising ZeroDivisionError.
    """
    hist = turns[:idx]
    if len(hist) < 1 or lp <= 0:
        return [0.0] * 10

    def clamp(value, lo, hi):
        return min(max(value, lo), hi)

    sp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 0]
    bp_prices = [r["offer_price"] for r in hist if int(r["party"]) == 1]
    s_vel = ((sp_prices[-1] - sp_prices[0]) / lp) \
        if len(sp_prices) > 1 else 0.0
    b_vel = ((bp_prices[-1] - bp_prices[0]) / lp) \
        if len(bp_prices) > 1 else 0.0
    gap_r = ((sp_prices[-1] - bp_prices[-1]) / lp) \
        if (sp_prices and bp_prices) else 1.0
    # Concessions: seller moving down, buyer moving up (backtracks ignored).
    s_con = sum(
        max(0, sp_prices[i-1] - sp_prices[i])
        for i in range(1, len(sp_prices))
    ) / lp if len(sp_prices) > 1 else 0.0
    b_con = sum(
        max(0, bp_prices[i] - bp_prices[i-1])
        for i in range(1, len(bp_prices))
    ) / lp if len(bp_prices) > 1 else 0.0
    stalls = (sum(1 for r in hist if r["msg_type"] == "stall")
              / max(len(hist), 1))
    searches = (sum(1 for r in hist if r["msg_type"] == "search")
                / max(len(hist), 1))
    # Bound-relative positions. These were previously clamped only from above,
    # so e.g. a budget close to the estimate drove budget_dist far below -2
    # while every other feature stays bounded; clamp both sides.
    budget_dist = clamp(
        (bp_prices[-1] - b_estimate) / max(b_budget - b_estimate, 1),
        -2.0, 2.0,
    ) if (b_budget > 0 and bp_prices) else 0.0
    floor_dist = clamp(
        (sp_prices[-1] - s_reserve) / max(lp - s_reserve, 1),
        -1.5, 1.5,
    ) if (s_reserve > 0 and sp_prices) else 0.5
    turns_norm = min(idx / 25.0, 1.0)
    return [
        float(s_vel - b_vel),
        float(clamp(gap_r, 0.0, 2.0)),
        float(min(s_con, 2.0)),
        float(min(b_con, 2.0)),
        float(stalls),
        float(searches),
        float(budget_dist),
        float(floor_dist),
        float(turns_norm),
        0.0,
    ]
| # ================================================================ | |
# DATASET BUILDER — selective pin_memory (small tensors only)
| # ================================================================ | |
def build_pinned_dataset(rows: List[Dict]) -> TensorDataset:
    """Turn flat negotiation rows into a tokenized ``TensorDataset``.

    Rows are grouped by ``session_id`` and sorted by ``turn_number``;
    sessions with a non-positive ``list_price`` are skipped. Every turn
    ``i >= 1`` of a session becomes one sample:

    * text: the up-to-three preceding messages as "S: ..."/"B: ...",
      joined by " [SEP] " and tokenized to ``MAX_LEN`` (padded/truncated);
    * inputs: the target turn's party, category, normalised turn number and
      personas, the last *visible* offer (turn ``i-1``) / list price, and
      ``extract_features(turns, i, ...)``;
    * labels: the target turn's message-type index and its offer / list
      price (both price ratios capped at 3.0).

    Returns:
        ``TensorDataset`` with tensors in the order
        ids, mask, pty, cat, ofn, tn, mt, pt, bp, sp, mom.
        On CUDA every tensor except ids and mask is page-locked.
    """
    chunk = 20000  # texts tokenized per batch, bounds tokenizer memory
    sessions: Dict[str, List[Dict]] = {}
    for r in rows:
        sessions.setdefault(r["session_id"], []).append(r)
    (texts, party_l, cat_l, ofn_l, tn_l,
     msg_l, pt_l, bp_l, sp_l, mom_l) = ([] for _ in range(10))
    for turns in sessions.values():
        turns = sorted(turns, key=lambda x: int(x["turn_number"]))
        lp = float(turns[0]["list_price"])
        if lp <= 0:
            continue
        b_bud = float(turns[0].get("buyer_budget", lp))
        b_est = float(turns[0].get("buyer_estimate", lp * 0.75))
        s_res = float(turns[0].get("seller_reservation", lp * 0.80))
        for i in range(1, len(turns)):
            tgt = turns[i]
            recent = turns[max(0, i - 3):i]
            text = " [SEP] ".join(
                f"{'S' if int(t['party']) == 0 else 'B'}: {t['message']}"
                for t in recent
            )
            mom = extract_features(turns, i, lp, b_bud, b_est, s_res)
            texts.append(text)
            party_l.append(int(tgt["party"]))
            cat_l.append(CAT2IDX.get(tgt["category"], 0))
            # ``ofn`` is a model input, so it must describe the latest offer
            # the model can already see (turn i-1). Taking the target turn's
            # own offer would copy the ``pt`` price label into the inputs.
            ofn_l.append(min(float(turns[i - 1]["offer_price"]) / lp, 3.0))
            tn_l.append(min(int(tgt["turn_number"]) / 25.0, 1.0))
            msg_l.append(MSG2IDX.get(tgt["msg_type"], 1))
            pt_l.append(min(float(tgt["offer_price"]) / lp, 3.0))
            bp_l.append(BPERSONA2IDX.get(
                tgt.get("buyer_persona", "patient"), 1))
            sp_l.append(SPERSONA2IDX.get(
                tgt.get("seller_persona", "firm"), 0))
            mom_l.append(mom)
    # Drops only this function's references; a caller holding ``rows`` keeps it alive.
    del sessions, rows
    gc.collect()
    n = len(texts)
    input_ids = torch.empty((n, MAX_LEN), dtype=torch.long)
    attn_mask = torch.empty((n, MAX_LEN), dtype=torch.long)
    for start in range(0, n, chunk):
        enc = tokenizer(
            texts[start:start + chunk], max_length=MAX_LEN,
            padding="max_length", truncation=True,
            return_tensors="pt"
        )
        input_ids[start:start + chunk] = enc["input_ids"]
        attn_mask[start:start + chunk] = enc["attention_mask"]
    del texts
    gc.collect()
    tensors = dict(
        ids=input_ids,
        mask=attn_mask,
        pty=torch.tensor(party_l, dtype=torch.long),
        cat=torch.tensor(cat_l, dtype=torch.long),
        ofn=torch.tensor(ofn_l, dtype=torch.float),
        tn=torch.tensor(tn_l, dtype=torch.float),
        mt=torch.tensor(msg_l, dtype=torch.long),
        pt=torch.tensor(pt_l, dtype=torch.float),
        bp=torch.tensor(bp_l, dtype=torch.long),
        sp=torch.tensor(sp_l, dtype=torch.long),
        # reshape keeps the (n, 10) shape even when there are no samples.
        mom=torch.tensor(mom_l, dtype=torch.float).reshape(-1, 10),
    )
    del party_l, cat_l, ofn_l, tn_l, msg_l, pt_l, bp_l, sp_l, mom_l
    gc.collect()
    # ── Selective pin_memory ──────────────────────────────────
    # ids and mask are (n, MAX_LEN) int64 each, by far the largest tensors;
    # only the small per-sample tensors are page-locked, to avoid locking
    # that much host RAM.
    # NOTE(review): DataLoader's default collate stacks samples into new
    # tensors, so batches are only pinned if the DataLoader itself uses
    # pin_memory=True — confirm this pinning actually speeds up transfers.
    if DEVICE.type == "cuda":
        small_keys = {"pty", "cat", "ofn", "tn", "mt", "pt", "bp", "sp", "mom"}
        tensors = {
            k: (v.pin_memory() if k in small_keys else v)
            for k, v in tensors.items()
        }
    return TensorDataset(*tensors.values())
| # ================================================================ | |
| # MODEL | |
| # ================================================================ | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then dropout.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d)``. Odd widths are supported: the last
    (even) column then has no cosine partner.

    Args:
        d: embedding width.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d: int, max_len: int = 512):
        super().__init__()
        self.drop = nn.Dropout(0.1)
        pe = torch.zeros(max_len, d)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d, 2).float() * (-math.log(10000.0) / d)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # There are d // 2 odd columns but ceil(d / 2) frequencies; slicing
        # avoids a shape mismatch when d is odd (identical for even d).
        pe[:, 1::2] = torch.cos(pos * div[: d // 2])
        # Registered as a buffer: follows .to(device) and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq, d)."""
        return self.drop(x + self.pe[:, :x.size(1)])
class MomentumEncoder(nn.Module):
    """Embed the hand-crafted momentum feature vector with a two-layer MLP.

    Linear(in_dim -> 64) -> GELU -> Linear(64 -> out_dim), held in ``net``
    so parameter names are ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, in_dim: int = 10, out_dim: int = 48):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map features of trailing size ``in_dim`` to ``out_dim``."""
        encoded = self.net(x)
        return encoded
class NegotiationTransformer(nn.Module):
    """Text encoder fused with structured negotiation context.

    Token ids are embedded, position-encoded and run through an
    ``N_LAYERS``-deep pre-norm ``nn.TransformerEncoder``. The first token's
    hidden state is concatenated with embeddings of party, category and both
    personas, the encoded momentum features, and the two scalars
    ``ofn``/``tn``. A fusion MLP then feeds two heads: message-type logits
    (one per ``MSG_TYPES`` entry) and a non-negative (Softplus) price value.
    """

    def __init__(self):
        super().__init__()
        party_dim, cat_dim, persona_dim, mom_dim, n_scalars = 32, 64, 32, 48, 2
        # Submodules are created in a fixed order: it fixes both the
        # state_dict layout and the order parameters are initialised in.
        self.emb = nn.Embedding(30522, D_MODEL, padding_idx=0)  # bert-base-uncased vocab size
        self.pos = PositionalEncoding(D_MODEL)
        layer = nn.TransformerEncoderLayer(
            D_MODEL, N_HEADS, FFN_DIM,
            dropout=0.1, batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, N_LAYERS)
        self.p_emb = nn.Embedding(2, party_dim)
        self.c_emb = nn.Embedding(len(CATEGORIES), cat_dim)
        self.bp_emb = nn.Embedding(len(BUYER_PERSONAS), persona_dim)
        self.sp_emb = nn.Embedding(len(SELLER_PERSONAS), persona_dim)
        self.mom_enc = MomentumEncoder(10, mom_dim)
        fused_in = (D_MODEL + party_dim + cat_dim + 2 * persona_dim
                    + mom_dim + n_scalars)
        self.fusion = nn.Sequential(
            nn.Linear(fused_in, D_MODEL), nn.GELU(), nn.Dropout(0.1)
        )
        self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
        self.px_head = nn.Sequential(
            nn.Linear(D_MODEL, 128), nn.GELU(),
            nn.Linear(128, 1), nn.Softplus(),
        )

    def forward(self, ids, mask, party, cat, ofn, tn, bp, sp, mom):
        """Return ``(msg_logits, price)``; ``mask`` is 1 for real tokens, 0 for padding."""
        hidden = self.encoder(
            self.pos(self.emb(ids)), src_key_padding_mask=mask.eq(0)
        )
        pieces = (
            hidden[:, 0],
            self.p_emb(party),
            self.c_emb(cat),
            self.bp_emb(bp),
            self.sp_emb(sp),
            self.mom_enc(mom),
            torch.stack([ofn, tn], dim=1),
        )
        fused = self.fusion(torch.cat(pieces, dim=1))
        return self.msg_head(fused), self.px_head(fused).squeeze(1)
class AsymmetricNegotiationLoss(nn.Module):
    """Party-weighted message-type cross-entropy plus a price MSE term.

    Each sample's cross-entropy is scaled by a class weight taken from the
    seller table (party 0) or the buyer table (party 1); samples with any
    other party value contribute 0. The result is averaged over the whole
    batch, and ``0.5 * MSE(px_pred, px_targets)`` is added.
    """

    def __init__(self):
        super().__init__()
        # Class weights indexed like MSG_TYPES:
        # [offer, counter, accept, reject, exit, stall, search]
        self.seller_w = torch.tensor([1.0, 1.0, 1.5, 1.2, 1.3, 0.8, 1.1])
        self.buyer_w = torch.tensor([1.0, 1.0, 1.3, 1.0, 1.2, 0.9, 1.2])

    def forward(self, mt_logits, mt_targets, px_pred, px_targets, party):
        device = mt_logits.device
        seller_pick = self.seller_w.to(device)[mt_targets]
        buyer_pick = self.buyer_w.to(device)[mt_targets]
        sample_w = torch.where(
            party == 0,
            seller_pick,
            torch.where(party == 1, buyer_pick, torch.zeros_like(buyer_pick)),
        )
        nll = F.cross_entropy(mt_logits, mt_targets, reduction="none")
        msg_loss = (nll * sample_w).mean()
        return msg_loss + 0.5 * F.mse_loss(px_pred, px_targets)
| # ================================================================ | |
| # PLOT | |
| # ================================================================ | |
def plot_curve(losses):
    """Return a matplotlib Figure plotting ``losses`` against 1..len(losses).

    When ``losses`` is empty, the axes show a faint "No data yet"
    placeholder instead. The figure is closed in pyplot before it is
    returned: the Figure object can still be rendered (e.g. with
    ``savefig``), but repeated calls no longer pile up open pyplot figures.
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    if losses:
        ax.plot(range(1, len(losses) + 1), losses, "b-o", markersize=4)
        ax.set_title("Training Loss")
    else:
        ax.text(0.5, 0.5, "No data yet",
                ha="center", va="center", alpha=0.5)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    # Without this, every call (e.g. each progress update during training)
    # leaves a figure registered in pyplot, leaking memory and eventually
    # triggering matplotlib's "too many open figures" warning.
    plt.close(fig)
    return fig
| # ================================================================ | |
| # TRAINING | |
| # ================================================================ | |
def run_training(n_sessions, epochs, batch_size, lr):
    """Generate synthetic sessions, train a fresh model, and stream progress.

    Generator used as a Gradio click handler. Every yield is a 4-tuple
    ``(status text, log text, loss-curve figure, readiness text)``. The log
    text holds the last 20 timestamped lines. On success the trained model
    (in eval mode) is published as the module-level ``GLOBAL_MODEL``.
    Exceptions are reported through the log rather than propagated.

    Args:
        n_sessions: number of sessions passed to ``generate_sessions``.
        epochs: passes over the dataset; also the cosine schedule length.
        batch_size: DataLoader batch size. ``drop_last=True`` is used, so a
            dataset smaller than one batch yields no batches; this is
            reported as an error instead of "training" on nothing.
        lr: AdamW learning rate.
    """
    global GLOBAL_MODEL
    logs = []
    losses = []  # defined up front so the error path can still plot progress

    def log(msg):
        """Append a timestamped line (keeping the last 20), echo it, and return the joined log."""
        ts = time.strftime("%H:%M:%S")
        line = f"[{ts}] {msg}"
        logs.append(line)
        if len(logs) > 20:
            logs.pop(0)
        print(line)
        return "\n".join(logs)

    try:
        batch_size = int(batch_size)
        n_epochs = int(epochs)
        log_txt = log(f"Generating {int(n_sessions):,} sessions...")
        yield "π‘ Generating...", log_txt, plot_curve([]), "β Needs Training"
        rows = generate_sessions(int(n_sessions))
        log_txt = log(f"Generated {len(rows):,} rows. Building dataset...")
        yield "π‘ Tokenizing...", log_txt, plot_curve([]), "β Needs Training"
        dataset = build_pinned_dataset(rows)
        loader = DataLoader(
            dataset, batch_size=batch_size,
            shuffle=True, num_workers=0,
            pin_memory=False, drop_last=True,
        )
        total_batches = len(loader)
        log_txt = log(f"Dataset: {len(dataset):,} samples | "
                      f"{total_batches} batches | bs={batch_size}")
        if total_batches == 0:
            # drop_last discards the partial batch. With no full batch there
            # is nothing to train on, and the per-epoch ms/batch figure
            # would divide by zero.
            log_txt = log(f"ERROR: {len(dataset):,} samples is fewer than one "
                          f"batch of {batch_size}; lower Batch Size or raise "
                          f"Sessions.")
            yield "π΄ ERROR", log_txt, plot_curve([]), "β Needs Training"
            return
        yield "π‘ Building model...", log_txt, plot_curve([]), "β Needs Training"
        model = NegotiationTransformer().to(DEVICE)
        crit = AsymmetricNegotiationLoss()
        if hasattr(torch, "compile") and DEVICE.type == "cuda":
            try:
                model = torch.compile(model, backend="cudagraphs")
                log_txt = log("torch.compile (cudagraphs) applied")
            except Exception as ce:
                log_txt = log(f"compile skipped: {ce}")
        opt = AdamW(model.parameters(), lr=float(lr), weight_decay=1e-2)
        sch = CosineAnnealingLR(opt, T_max=n_epochs)
        # CUDA AMP only applies on a CUDA device. On CPU the scaler and
        # autocast would disable themselves with a warning anyway; gating
        # them here keeps the same numerics without the noise.
        use_amp = DEVICE.type == "cuda"
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        log_txt = log("π Training started")
        yield "π’ Training...", log_txt, plot_curve([]), "β Needs Training"
        for ep in range(n_epochs):
            model.train()
            ep_loss = 0.0
            t0 = time.time()
            for i, batch in enumerate(loader):
                (b_ids, b_mask, b_pty, b_cat, b_ofn,
                 b_tn, b_mt, b_pt, b_bp, b_sp, b_mom) = [
                    t.to(DEVICE, non_blocking=True) for t in batch
                ]
                if i % 100 == 0:
                    el = time.time() - t0
                    ms_b = (el / max(i, 1)) * 1000
                    status = (f"π’ Epoch {ep+1}/{n_epochs} | "
                              f"Batch {i}/{total_batches} | "
                              f"{ms_b:.0f}ms/batch")
                    log_txt = log(status)
                    yield (status, log_txt,
                           plot_curve(losses), "β Needs Training")
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    mt_logits, px_pred = model(
                        b_ids, b_mask, b_pty, b_cat,
                        b_ofn, b_tn, b_bp, b_sp, b_mom
                    )
                    loss = crit(mt_logits, b_mt, px_pred, b_pt, b_pty)
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is in real units.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                ep_loss += loss.item()
            sch.step()  # schedule is per epoch (T_max == number of epochs)
            avg = ep_loss / total_batches
            et = time.time() - t0
            losses.append(avg)
            log_txt = log(
                f"Epoch {ep+1}/{n_epochs} done β "
                f"loss: {avg:.4f} | {et:.1f}s | "
                f"{et/total_batches*1000:.0f}ms/batch"
            )
            yield (f"π’ Epoch {ep+1} done", log_txt,
                   plot_curve(losses), "β Needs Training")
        model.eval()
        GLOBAL_MODEL = model
        log_txt = log("β Training complete.")
        yield "π΅ Complete", log_txt, plot_curve(losses), "β Ready"
    except Exception as e:
        import traceback
        log_txt = log(f"ERROR: {e}\n{traceback.format_exc()}")
        # Keep whatever loss history was collected before the failure.
        yield "π΄ ERROR", log_txt, plot_curve(losses), "β Needs Training"
| # ================================================================ | |
| # INFERENCE ENGINE | |
| # ================================================================ | |
def _build_message(msg_type, price, item,
                   is_buyer, persona, inv_context=""):
    """Render the chat text for one AI move from the phrase templates.

    Only the template for ``msg_type`` is rendered; ``_t`` is assumed to
    format the named template with the given fields (TODO confirm against
    its definition).

    Args:
        msg_type: one of MSG_TYPES.
        price: price quoted in the message (``p`` template field).
        item: item name (``item`` field, used by opening offers).
        is_buyer: True when the speaking side is the buyer.
        persona: speaker persona. Aggressive buyers and motivated/desperate
            sellers get the firmer/softer opening and counter templates.
        inv_context: accepted for backward compatibility but not used. The
            ``search`` reply is always the plain search template, and
            ``run_inference_turn`` appends the inventory listing itself.
            (The old ``m["search"]`` branch was unreachable because
            ``search`` returns early.)

    Returns:
        The message string. Unknown ``msg_type`` values produce
        ``"<type> @ $<price>"``.
    """
    p = price
    if msg_type == "search":
        return _t("buyer_search") if is_buyer else _t("seller_search")
    # msg_type -> (template key, template fields), rendered lazily below.
    if is_buyer:
        bold = persona == "aggressive"
        templates = {
            "offer": ("buyer_open_aggressive" if bold else "buyer_open_strategic",
                      {"item": item, "p": p}),
            "counter": ("buyer_counter_hold" if bold else "buyer_counter_nibble",
                        {"p": p}),
            "accept": ("buyer_accept", {"p": p}),
            "reject": ("buyer_reject", {}),
            "exit": ("buyer_walkaway", {}),
            "stall": ("buyer_stall", {}),
        }
    else:
        soft = persona in ("motivated", "desperate")
        templates = {
            "offer": ("seller_open_motivated" if soft else "seller_open_firm",
                      {"item": item, "p": p}),
            "counter": ("seller_counter_concede" if soft else "seller_counter_hold",
                        {"p": p}),
            "accept": ("seller_accept", {"p": p}),
            "reject": ("seller_reject", {}),
            "exit": ("seller_exit", {}),
            "stall": ("seller_stall", {}),
        }
    entry = templates.get(msg_type)
    if entry is None:
        return f"{msg_type} @ ${p:,.2f}"
    key, fields = entry
    return _t(key, **fields)
def _num_or_zero(value):
    """Return ``value`` as a float, treating ``None``/``""`` (a cleared Gradio field) as 0.0."""
    if value is None or value == "":
        return 0.0
    return float(value)


def run_inference_turn(
    session_state,
    category, item,
    list_price, user_price, user_message,
    user_party, user_persona, ai_persona,
    buyer_budget, buyer_estimate,
    buyer_avoids, buyer_must_have,
    seller_reservation, seller_urgency,
):
    """Play one user turn plus one model reply in the negotiation arena.

    The turn runs in this order:

    1. Record the user's offer and message.
    2. Build the model inputs: the last three messages as text (prefixed
       with any previous inventory context), the AI party, category,
       offer / list ratio, turn progress, personas, and 10 momentum
       features.
    3. Predict the AI's message type and price, then clamp the price to
       the AI side's bounds.
    4. Optionally run an inventory search.

    An AI "accept" closes the deal at the user's offered price. If that
    offer breaks the AI side's private bound (seller: below the reservation;
    buyer: above the budget), the move is downgraded to "counter". Action
    probabilities still show the raw model output.

    Non-positive bound inputs fall back to list-price-derived defaults:
    budget = list price, estimate = 75 %, reservation = 80 %.
    ``seller_urgency`` is accepted for the UI wiring but is not used here.

    Returns:
        7-tuple ``(session_state, chat history, summary, AI message type,
        AI price text, action-probability text, inventory text)``.
    """
    if GLOBAL_MODEL is None:
        return (session_state,
                session_state.get("history_ui", []),
                "Model not trained.", "", "", "", "")
    lp = _num_or_zero(list_price)
    if lp <= 0:
        # Every price feature below is normalised by the list price.
        return (session_state,
                session_state.get("history_ui", []),
                "List price must be greater than zero.", "", "", "", "")
    if user_price is None:
        return (session_state,
                session_state.get("history_ui", []),
                "Enter an offer price before sending.", "", "", "", "")
    u_price = float(user_price)
    is_user_buyer = (user_party == "Buyer")
    ai_party_int = 0 if is_user_buyer else 1  # 0 = seller, 1 = buyer
    est_in = _num_or_zero(buyer_estimate)

    # -- Initialise session on first use --
    if not session_state.get("started"):
        init_bp = est_in if est_in > 0 else round(lp * 0.75, -1)
        session_state = {
            "started": True,
            "turn": 0,
            "sp": lp,
            "bp": init_bp,
            "history": [],
            "history_ui": [],
            "status": "active",
            "inv_context": "",
        }
    if session_state["status"] != "active":
        return (session_state, session_state["history_ui"],
                "Session ended β click New Session to restart.",
                "", "", "", "")

    history = session_state["history"]
    history_ui = session_state["history_ui"]
    sp = float(session_state["sp"])
    bp = float(session_state["bp"])
    turn = session_state["turn"]
    budget_in = _num_or_zero(buyer_budget)
    res_in = _num_or_zero(seller_reservation)
    b_bud = budget_in if budget_in > 0 else lp
    b_est = est_in if est_in > 0 else lp * 0.75
    s_res = res_in if res_in > 0 else lp * 0.80

    # -- Record user turn --
    u_int = 1 if is_user_buyer else 0
    history.append({
        "party": u_int,
        "message": user_message,
        "offer_price": u_price,
        "msg_type": "counter",
        "turn_number": turn + 1,
    })
    history_ui.append((
        f"{'π§ You (Buyer)' if is_user_buyer else 'π§ You (Seller)'}"
        f" [${u_price:,.0f}]: {user_message}",
        None
    ))
    turn += 1
    if is_user_buyer:
        bp = u_price
    else:
        sp = u_price

    # -- Momentum features (10 values) --
    sp_prices = [r["offer_price"] for r in history if int(r["party"]) == 0]
    bp_prices = [r["offer_price"] for r in history if int(r["party"]) == 1]
    s_vel = ((sp_prices[-1] - sp_prices[0]) / lp) if len(sp_prices) > 1 else 0.0
    b_vel = ((bp_prices[-1] - bp_prices[0]) / lp) if len(bp_prices) > 1 else 0.0
    gap_r = ((sp - bp) / lp) if sp > bp else 0.0
    stalls = (sum(1 for r in history if r["msg_type"] == "stall")
              / max(len(history), 1))
    srch = (sum(1 for r in history if r["msg_type"] == "search")
            / max(len(history), 1))
    b_dist = min((bp - b_est) / max(b_bud - b_est, 1), 2.0) \
        if (b_bud > 0 and bp_prices) else 0.0
    f_dist = min((sp - s_res) / max(lp - s_res, 1), 1.5) \
        if (s_res > 0 and sp_prices) else 0.5
    mom = [
        float(s_vel - b_vel),
        float(min(max(gap_r, 0.0), 2.0)),
        0.0, 0.0,
        float(stalls), float(srch),
        float(b_dist), float(f_dist),
        float(min(turn / 25.0, 1.0)),
        0.0,
    ]

    # -- Text context: last three messages, plus prior inventory context --
    inv_ctx = session_state.get("inv_context", "")
    recent = history[-3:]
    text = " [SEP] ".join(
        f"{'S' if int(r['party'])==0 else 'B'}: {r['message']}"
        for r in recent
    )
    if inv_ctx:
        text = f"[INV: {inv_ctx[:120]}] " + text
    enc = tokenizer(
        text, max_length=MAX_LEN,
        padding="max_length", truncation=True,
        return_tensors="pt"
    )
    dev = DEVICE
    ai_pty_t = torch.tensor([ai_party_int], dtype=torch.long).to(dev)
    cat_t = torch.tensor([CAT2IDX.get(category, 0)],
                         dtype=torch.long).to(dev)
    ofn_t = torch.tensor([min(u_price / lp, 3.0)],
                         dtype=torch.float).to(dev)
    tn_t = torch.tensor([min(turn / 25.0, 1.0)],
                        dtype=torch.float).to(dev)
    bp_idx = BPERSONA2IDX.get(
        user_persona if is_user_buyer else ai_persona, 1)
    sp_idx = SPERSONA2IDX.get(
        ai_persona if is_user_buyer else user_persona, 0)
    bp_t = torch.tensor([bp_idx], dtype=torch.long).to(dev)
    sp_t = torch.tensor([sp_idx], dtype=torch.long).to(dev)
    mom_t = torch.tensor([mom], dtype=torch.float).to(dev)
    with torch.no_grad():
        mt_logits, px = GLOBAL_MODEL(
            enc["input_ids"].to(dev),
            enc["attention_mask"].to(dev),
            ai_pty_t, cat_t, ofn_t, tn_t, bp_t, sp_t, mom_t
        )
    mt_idx = mt_logits.argmax(dim=1).item()
    msg_type = IDX2MSG[mt_idx]
    # The price head predicts a fraction of the list price.
    ai_price = round(float(px.item()) * lp, 2)

    # -- Clamp AI price to the AI side's valid range --
    if ai_party_int == 0:  # AI is seller: never quote below the floor
        ai_price = max(ai_price, s_res * 1.005)
        ai_price = min(ai_price, lp * 1.05)
        sp = ai_price
    else:  # AI is buyer: stay within budget and below the seller's ask
        ai_price = min(ai_price, b_bud)
        ai_price = min(ai_price, sp * 0.99)
        ai_price = max(ai_price, lp * 0.25)
        bp = ai_price

    # -- Accept means agreeing to the user's standing offer --
    if msg_type == "accept":
        breaks_bound = (u_price < s_res) if ai_party_int == 0 else (u_price > b_bud)
        if breaks_bound:
            # Never accept below the seller floor or above the buyer budget.
            msg_type = "counter"
        else:
            ai_price = round(u_price, 2)
            sp = bp = ai_price

    # -- Execute inventory search if triggered --
    inv_context_text = ""
    if msg_type == "search":
        if ai_party_int == 0:  # Seller searches for buyer
            results = search_inventory(
                category=category,
                max_price=b_bud if b_bud > 0 else lp * 1.1,
                min_price=b_est * 0.8 if b_est > 0 else 0,
                keywords=buyer_must_have,
                avoids=buyer_avoids,
                top_k=3,
            )
            # NOTE(review): this listing is appended to the AI seller's chat
            # reply, which the (human) buyer sees. Confirm reveal_floor=True
            # does not expose seller floor prices there.
            inv_context_text = format_inventory_context(
                results, reveal_floor=True
            )
        else:  # Buyer searches seller inventory
            results = search_inventory(
                category=category,
                max_price=b_bud if b_bud > 0 else lp,
                keywords=buyer_must_have,
                avoids=buyer_avoids,
                top_k=3,
            )
            inv_context_text = format_inventory_context(
                results, reveal_floor=False
            )
        session_state["inv_context"] = inv_context_text

    # -- Build AI message --
    ai_msg = _build_message(
        msg_type, ai_price, item,
        not is_user_buyer, ai_persona,
        inv_context_text
    )
    if msg_type == "search" and inv_context_text:
        ai_msg += (f"\n\nπ¦ **Inventory Results:**\n"
                   f"```\n{inv_context_text}\n```")
    history.append({
        "party": ai_party_int,
        "message": ai_msg,
        "offer_price": ai_price,
        "msg_type": msg_type,
        "turn_number": turn + 1,
    })
    ai_label = (f"π€ AI ({'Seller' if ai_party_int==0 else 'Buyer'}) "
                f"[{ai_persona}]")
    history_ui.append((None, f"{ai_label} [${ai_price:,.0f}]: {ai_msg}"))
    turn += 1

    # -- ZOPA: current buyer price vs seller reservation --
    zopa = bp - s_res
    zopa_str = (f"β ZOPA: +${zopa:,.0f} (deal zone exists)"
                if zopa > 0
                else f"β ZOPA: ${zopa:,.0f} (no overlap yet)")

    # -- Terminal check --
    status = "active"
    if msg_type == "accept":
        status = "closed"
        history_ui.append(
            (None, f"β **DEAL CLOSED at ${ai_price:,.0f}**")
        )
    elif msg_type == "exit":
        status = "ended"
        history_ui.append((None, "β Negotiation ended"))
    probs = F.softmax(mt_logits, dim=1).squeeze().tolist()
    prob_str = " | ".join(
        f"{MSG_TYPES[i]}: {probs[i]:.2f}" for i in range(len(MSG_TYPES))
    )
    gap_pct = abs(sp - bp) / lp * 100
    summary = (f"Turn {turn} | Gap: {gap_pct:.1f}% | "
               f"Seller: ${sp:,.0f} | Buyer: ${bp:,.0f} | {zopa_str}")
    session_state.update({
        "turn": turn,
        "sp": sp,
        "bp": bp,
        "history": history,
        "history_ui": history_ui,
        "status": status,
    })
    return (session_state, history_ui, summary,
            msg_type, f"${ai_price:,.2f}", prob_str, inv_context_text)
def reset_session():
    """Produce the cleared outputs for the "New Session" button.

    Returns a fresh empty state dict and chat list, followed by the reset
    notice and four empty analysis strings, in ``send_btn`` output order.
    """
    cleared_text = ("Session reset.", "", "", "", "")
    return ({}, []) + cleared_text
| # ================================================================ | |
| # STRATEGY GUIDES | |
| # ================================================================ | |
# Markdown shown on the Playbooks tab. The blank line before the "Tactics"
# label matters: without it, CommonMark treats that line as a lazy
# continuation of the previous bullet and renders it inside the list item.
BUYER_GUIDE = """### π Buyer Playbook
**Bounds to set before starting:**
- **Budget** β your true ceiling. Encoded as soft penalty, not hard wall.
- **Estimate** β fair value anchor. Sets your opening offer range.
- **Must-have features** β filters inventory search. e.g. *bluetooth, low miles*
- **Hard avoids** β instant deal-breakers. e.g. *salvage title, high mileage*

**Tactics the model trains on:**
- π΄ Aggressive open at 55-65% of ask
- πͺ Walk away at turn 3-4, return with prior offer
- π Trigger search when gap > 12%: *"Do you have anything else?"*
- β° Deadline pressure after patience threshold
- πͺ Nibble for extras when gap < 8%
- π€ Strategic persona: cite comps, build rapport"""
# Markdown shown on the Playbooks tab. The blank line before the "Tactics"
# label keeps CommonMark from folding it into the last bullet as a lazy
# continuation line.
SELLER_GUIDE = """### π Seller Playbook
**Bounds to set before starting:**
- **Reservation price** β private floor. Model NEVER accepts below this.
- **Urgency** β high urgency raises concession rate and search frequency.
- **Inventory** β pre-loaded. Searched when buyer asks for alternatives.

**Tactics the model trains on:**
- β Open 15-20% above target
- π₯ Social proof: *"Two other buyers this weekend"*
- π Proactively search inventory when buyer signals dissatisfaction
- β° Urgency close: *"Need to close by Friday"*
- π Return after walkaway with small concession
- π Shrinking concessions signal approaching floor"""
| # ================================================================ | |
| # UI | |
| # ================================================================ | |
# Gradio UI: three working tabs (Training, Negotiation Arena, Inventory) plus
# a Playbooks tab, per-browser session state, and the event wiring that
# connects widgets to run_training / run_inference_turn / reset_session.
with gr.Blocks(title="ANP v5 | Bounded Negotiation",
               theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ANP v5 β Bounded Negotiation Engine\n"
        "Buyer bounds Β· Seller reservation Β· Inventory tool use Β· "
        "ZOPA tracking Β· Persona conditioning"
    )
    # -- Training tab: hyper-parameters in, streamed status/log/loss out --
    with gr.Tab("ποΈ Training"):
        with gr.Row():
            n_sessions = gr.Number(value=20000, label="Sessions")
            epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
            batch_size = gr.Slider(64, 1024, value=512, step=64,
                                   label="Batch Size")
            lr = gr.Number(value=3e-4, label="LR")
        tr_btn = gr.Button("π Train", variant="primary")
        status_box = gr.Textbox(label="Status", interactive=False,
                                value="π΅ IDLE")
        with gr.Row():
            log_box = gr.Textbox(label="Logs", lines=14, interactive=False)
            plt_out = gr.Plot(label="Loss Curve")
        # Hidden sink for run_training's 4th output (readiness text).
        train_ready = gr.Textbox(visible=False)
    # -- Negotiation Arena tab: setup/bounds on the left, chat on the right --
    with gr.Tab("π¬ Negotiation Arena"):
        with gr.Row():
            # Left panel: session setup, party bounds, per-turn analysis
            with gr.Column(scale=1):
                gr.Markdown("### βοΈ Session Setup")
                arena_cat = gr.Dropdown(
                    CATEGORIES, value="used_car", label="Category"
                )
                arena_item = gr.Textbox(
                    value="2019 Honda Civic", label="Item Name"
                )
                arena_lp = gr.Number(value=18500, label="List Price ($)")
                with gr.Row():
                    arena_user_pty = gr.Radio(
                        ["Buyer", "Seller"], value="Buyer", label="You are"
                    )
                with gr.Row():
                    # Choices are swapped by update_personas when the user
                    # changes side; defaults assume the user is the buyer.
                    arena_user_persona = gr.Dropdown(
                        BUYER_PERSONAS, value="strategic",
                        label="Your Persona"
                    )
                    arena_ai_persona = gr.Dropdown(
                        SELLER_PERSONAS, value="firm",
                        label="AI Persona"
                    )
                gr.Markdown("---\n### π§ Buyer Bounds")
                # Non-positive values make run_inference_turn fall back to
                # list-price-derived defaults.
                buyer_budget = gr.Number(value=17000,
                                         label="Max Budget ($)")
                buyer_estimate = gr.Number(value=15500,
                                           label="Fair Value Estimate ($)")
                # Comma-separated strings passed straight to search_inventory.
                buyer_avoids = gr.Textbox(
                    value="salvage,flood",
                    label="Hard Avoids (comma list)"
                )
                buyer_must_have = gr.Textbox(
                    value="bluetooth",
                    label="Must-Have Features (comma list)"
                )
                gr.Markdown("---\n### π€ Seller Bounds")
                seller_reservation = gr.Number(
                    value=15000, label="Seller Floor / Reservation ($)"
                )
                seller_urgency = gr.Dropdown(
                    ["low", "medium", "high"], value="medium",
                    label="Seller Urgency"
                )
                reset_btn = gr.Button("π New Session", variant="secondary")
                gr.Markdown("---\n### π Turn Analysis")
                # Read-only outputs filled by run_inference_turn each turn.
                arena_summary = gr.Textbox(
                    label="Gap / ZOPA", interactive=False
                )
                arena_action = gr.Textbox(
                    label="AI Action", interactive=False
                )
                arena_price = gr.Textbox(
                    label="AI Price", interactive=False
                )
                arena_probs = gr.Textbox(
                    label="Action Probabilities", interactive=False
                )
                inv_display = gr.Textbox(
                    label="π Last Inventory Search",
                    lines=5, interactive=False
                )
            # Right panel: conversation and the user's offer/message input
            with gr.Column(scale=2):
                gr.Markdown("### π£οΈ Negotiation")
                chatbot = gr.Chatbot(height=520, label="Conversation")
                with gr.Row():
                    arena_offer = gr.Number(value=16000,
                                            label="Your Offer ($)")
                    arena_msg = gr.Textbox(
                        placeholder="Type your message...",
                        label="Your Message", scale=3
                    )
                    send_btn = gr.Button("Send β", variant="primary")
    # -- Playbooks tab: static strategy notes for both sides --
    with gr.Tab("π Playbooks"):
        with gr.Row():
            gr.Markdown(BUYER_GUIDE)
            gr.Markdown(SELLER_GUIDE)
    # -- Inventory tab: read-only dump of every INVENTORY row --
    with gr.Tab("π¦ Inventory"):
        gr.Markdown(
            "### Current Inventory Database\n"
            "Plain text rows β term-frequency search, no vectors at rest."
        )
        # Built once at UI construction time; later INVENTORY changes are
        # not reflected here.
        inv_text = "\n".join(
            f"[{it['id']}] {it['category']} | {it['name']} | "
            f"{it['condition']} | Ask: ${it['ask_price']:,} | "
            f"Features: {it['features']}"
            for it in INVENTORY
        )
        gr.Textbox(
            value=inv_text, lines=30, interactive=False,
            label="Inventory (floor hidden from buyer-facing searches)"
        )
    # -- State: per-session negotiation dict, emptied by reset_session --
    session_state = gr.State({})

    def update_personas(party):
        """Swap persona choices so each dropdown matches its side.

        When the user is the buyer, "Your Persona" lists buyer personas and
        "AI Persona" lists seller personas; otherwise the lists are swapped.
        NOTE(review): returning component instances as updates is the
        Gradio 4 convention; confirm the pinned Gradio version.
        """
        if party == "Buyer":
            return (
                gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
                gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
            )
        return (
            gr.Dropdown(choices=SELLER_PERSONAS, value="firm"),
            gr.Dropdown(choices=BUYER_PERSONAS, value="strategic"),
        )
    # -- Event wiring --
    arena_user_pty.change(
        update_personas,
        inputs=[arena_user_pty],
        outputs=[arena_user_persona, arena_ai_persona]
    )
    # run_training is a generator, so these outputs update on every yield.
    tr_btn.click(
        run_training,
        inputs=[n_sessions, epochs, batch_size, lr],
        outputs=[status_box, log_box, plt_out, train_ready]
    )
    # Input order must match run_inference_turn's positional parameters.
    send_btn.click(
        run_inference_turn,
        inputs=[
            session_state,
            arena_cat, arena_item, arena_lp,
            arena_offer, arena_msg,
            arena_user_pty, arena_user_persona, arena_ai_persona,
            buyer_budget, buyer_estimate,
            buyer_avoids, buyer_must_have,
            seller_reservation, seller_urgency,
        ],
        outputs=[
            session_state, chatbot, arena_summary,
            arena_action, arena_price, arena_probs, inv_display,
        ]
    )
    # Same 7 outputs as send_btn, cleared by reset_session.
    reset_btn.click(
        reset_session,
        outputs=[
            session_state, chatbot, arena_summary,
            arena_action, arena_price, arena_probs, inv_display,
        ]
    )
# Launch only when run as a script, so importing this module (e.g. Gradio
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    # run_training is a generator handler. Gradio 3.x refuses generators
    # unless the queue is enabled, and Gradio 4 enables it by default.
    # queue() returns the Blocks instance, so the call chains into launch().
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)