everydaytok committed on
Commit
8376a9f
·
verified ·
1 Parent(s): e66f339

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ================================================================
2
+ # ANP Model | HF Free Tier (16GB CPU) | Background Training Daemon
3
+ # ================================================================
4
+ import os, time, math, random, uuid, threading
5
+ from typing import List, Dict
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.optim import AdamW
12
+ from torch.optim.lr_scheduler import CosineAnnealingLR
13
+ from transformers import BertTokenizerFast
14
+
15
+ import gradio as gr
16
+ import matplotlib
17
+ matplotlib.use("Agg")
18
+ import matplotlib.pyplot as plt
19
+
20
# Seed both the Python and the torch RNGs so data generation and weight
# initialisation are repeatable from run to run.
random.seed(42)
torch.manual_seed(42)

# ── Config & Globals ──────────────────────────────────────────
DEVICE = torch.device("cpu")  # HF Free tier is CPU

# Action vocabulary of the classification head, with lookups both ways.
MSG_TYPES = ["offer", "counter", "accept", "reject", "exit", "stall"]
MSG2IDX = {name: idx for idx, name in enumerate(MSG_TYPES)}
IDX2MSG = dict(enumerate(MSG_TYPES))

# Item categories known to the category embedding.
CATEGORIES = [
    "used_car",
    "domain_name",
    "freelance_design",
    "saas_license",
    "electronics",
    "bulk_groceries",
    "consulting",
]
CAT2IDX = {name: idx for idx, name in enumerate(CATEGORIES)}

# Tokenizer / transformer hyper-parameters.
MAX_LEN = 256    # tokens per encoded history (padded / truncated)
D_MODEL = 384    # transformer hidden width
N_HEADS = 6      # attention heads per layer
N_LAYERS = 6     # encoder layers
FFN_DIM = 1024   # feed-forward width inside each encoder layer

print("Loading tokenizer...")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
39
+
40
+ # ── Thread-Safe State Manager ─────────────────────────────────
41
class TrainingState:
    """Shared, lock-guarded progress record for the background trainer.

    The training thread writes to it and the dashboard reads from it.
    Callers that touch several attributes together should hold ``lock``.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # Run / progress flags.
        self.is_running = False
        self.model_ready = False
        self.current_epoch = 0
        self.total_epochs = 0
        self.batch_progress = ""
        # Rolling log lines and per-epoch average losses.
        self.logs = []
        self.losses = []

    def log(self, msg: str):
        """Append a timestamped line to ``logs`` and echo ``msg`` to stdout.

        After each append, the oldest line is dropped once the list holds
        more than 50 entries.
        """
        stamp = time.strftime("%H:%M:%S")
        entry = f"[{stamp}] {msg}"
        with self.lock:
            self.logs.append(entry)
            if len(self.logs) > 50:  # Keep dashboard clean
                del self.logs[0]
        print(msg)


# Process-wide singletons shared by the trainer thread and the UI callbacks.
STATE = TrainingState()
GLOBAL_MODEL = None  # Holds the model in memory for inference
62
+
63
+ # ── Synthetic Data Generator ──────────────────────────────────
64
def generate_sessions(n_sessions: int) -> List[Dict]:
    """Generate synthetic seller/buyer negotiation transcripts in memory.

    Each session opens with a seller "offer" at the list price and a buyer
    "counter" at 60-80% of it, followed by 2-6 bargaining rounds in which
    both sides concede a little. A session ends with an "accept" (gap below
    5% of list price, "accepted" sessions only), a "stall" + "exit" pair
    ("abandoned" sessions, 20% chance per round), or a buyer "exit" when the
    rounds run out.

    Args:
        n_sessions: Number of sessions to generate; values <= 0 yield ``[]``.

    Returns:
        A flat list of turn dicts with keys ``session_id``, ``turn_number``
        (1-based), ``party`` (0 = seller, 1 = buyer), ``category``, ``item``,
        ``list_price``, ``offer_price``, ``msg_type`` and ``message``.
    """
    all_rows: List[Dict] = []

    # Simple templates for generator (training text)
    # Naming: S* = seller lines, B* = buyer lines.
    _SO = ["{item} for sale. Asking ${p:,.0f}.", "Listing {item} at ${p:,.0f}."]
    _SC = ["Best I can do is ${p:,.0f}.", "Can come down to ${p:,.0f}."]
    _SS = ["Let me think about it.", "Need to check with my partner."]
    _SA = ["Deal at ${p:,.0f}.", "Agreed. ${p:,.0f}."]
    _BC = ["Offering ${p:,.0f}.", "${p:,.0f} is my ceiling."]
    _BE = ["Too far apart. Going to pass.", "Price doesn't work for me."]

    def _t(templates, item="", p=0):
        return random.choice(templates).format(item=item, p=p)

    for _ in range(n_sessions):
        cat = random.choice(CATEGORIES)
        item = f"Generic {cat} Item"
        lp = round(random.uniform(500, 10000), -1)
        # Use the full 128-bit UUID. A 6-hex-char (24-bit) prefix collides
        # dozens of times at tens of thousands of sessions, and colliding
        # sessions would be merged by anything that groups rows by session_id.
        sid = f"SYN-{uuid.uuid4().hex.upper()}"
        session_rows: List[Dict] = []

        def add(party, price, mtype, msg):
            # turn_number is simply the 1-based position within the session.
            session_rows.append({
                "session_id": sid, "turn_number": len(session_rows) + 1, "party": party,
                "category": cat, "item": item, "list_price": lp,
                "offer_price": price, "msg_type": mtype, "message": msg,
            })

        sp = lp                                        # seller's standing price
        bp = round(lp * random.uniform(0.6, 0.8), -1)  # buyer's standing price

        add(0, sp, "offer", _t(_SO, item=item, p=sp))
        add(1, bp, "counter", _t(_BC, p=bp))

        # NOTE(review): no branch emits a "reject" message, so that class
        # never appears in the data; "rejected" sessions end in a buyer "exit".
        target = random.choice(["accepted", "abandoned", "rejected"])
        for _ in range(random.randint(2, 6)):
            gap = sp - bp
            if target == "accepted" and (gap / lp) < 0.05:
                final_p = round((sp + bp) / 2, -1)
                add(0 if random.random() < 0.5 else 1, final_p, "accept", _t(_SA, p=final_p))
                break
            if target == "abandoned" and random.random() < 0.2:
                add(0, sp, "stall", _t(_SS))
                add(1, bp, "exit", _t(_BE))
                break

            # Seller concedes 2-5% of list price but keeps 30% of the gap.
            sp = max(bp + gap * 0.3, sp - lp * random.uniform(0.02, 0.05))
            sp = round(sp, -1)
            add(0, sp, "counter", _t(_SC, p=sp))

            gap = sp - bp
            if target == "accepted" and (gap / lp) < 0.05:
                final_p = round((sp + bp) / 2, -1)
                add(1, final_p, "accept", _t(_SA, p=final_p))
                break

            # Buyer raises 2-5% of list price but keeps 30% of the gap.
            bp = min(sp - gap * 0.3, bp + lp * random.uniform(0.02, 0.05))
            bp = round(bp, -1)
            add(1, bp, "counter", _t(_BC, p=bp))
        else:
            # Rounds exhausted without a terminal move: the buyer walks away.
            add(1, bp, "exit", _t(_BE))

        all_rows.extend(session_rows)
    return all_rows
133
+
134
+ # ── Dataset & Model ───────────────────────────────────────────
135
class NegotiationDataset(Dataset):
    """Next-move prediction samples built from flat negotiation turn rows.

    Rows are grouped by ``session_id`` and ordered by ``turn_number``. Every
    turn after the first becomes one sample: the preceding turns form the
    text history and the turn itself supplies the labels (``msg_type``,
    ``price_t``). Sessions whose list price is not positive are skipped.
    Text is tokenized lazily in ``__getitem__``.
    """

    def __init__(self, rows: List[Dict]):
        self.samples = []
        sessions = {}
        for row in rows:
            sessions.setdefault(row["session_id"], []).append(row)

        for turns in sessions.values():
            turns = sorted(turns, key=lambda r: int(r["turn_number"]))
            lp = float(turns[0]["list_price"])
            if lp <= 0:
                continue

            # Render each turn once; the history of sample i is lines[:i].
            lines = [
                f"{'Seller' if t['party'] == 0 else 'Buyer'}: {t['message']}"
                for t in turns
            ]

            for i in range(1, len(turns)):
                prev, tgt = turns[i - 1], turns[i]
                self.samples.append({
                    "text": " [SEP] ".join(lines[:i]),
                    "party": int(tgt["party"]),
                    "category": CAT2IDX.get(tgt["category"], 0),
                    # Input feature: the offer currently on the table, i.e.
                    # the LAST HISTORY turn's price. Taking it from the
                    # target turn would hand the model its own regression
                    # label (price_t) as an input.
                    "ofn": min(float(prev["offer_price"]) / lp, 3.0),
                    "tn": min(int(tgt["turn_number"]) / 20.0, 1.0),
                    "msg_type": MSG2IDX.get(tgt["msg_type"], 1),
                    # Regression label: target price as a fraction of list
                    # price, capped at 3.0.
                    "price_t": min(float(tgt["offer_price"]) / lp, 3.0),
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its fields as tensors."""
        s = self.samples[idx]
        enc = tokenizer(
            s["text"],
            max_length=MAX_LEN,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attn_mask": enc["attention_mask"].squeeze(0),
            "party": torch.tensor(s["party"], dtype=torch.long),
            "category": torch.tensor(s["category"], dtype=torch.long),
            "ofn": torch.tensor(s["ofn"], dtype=torch.float),
            "tn": torch.tensor(s["tn"], dtype=torch.float),
            "msg_type": torch.tensor(s["msg_type"], dtype=torch.long),
            "price_t": torch.tensor(s["price_t"], dtype=torch.float),
        }
175
+
176
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Args:
        d: Embedding width. Odd widths are supported; the last cosine
            column is simply omitted.
        max_len: Longest sequence length the precomputed table covers.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
        angles = pos * div  # (max_len, ceil(d / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer odd-indexed column than frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d // 2])
        # Buffer: follows .to()/state_dict but is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``."""
        return self.drop(x + self.pe[:, :x.size(1)])
188
+
189
class NegotiationTransformer(nn.Module):
    """Transformer encoder over the history text with two output heads.

    The first-token encoding is fused with party and category embeddings
    and two scalar features, then fed to a message-type classifier
    (``msg_head``) and a non-negative price regressor (``px_head``).
    """

    def __init__(self):
        super().__init__()
        # Text branch: token embedding -> positional encoding -> encoder.
        self.emb = nn.Embedding(30522, D_MODEL, padding_idx=0)
        self.pos = PositionalEncoding(D_MODEL)
        layer = nn.TransformerEncoderLayer(
            D_MODEL,
            N_HEADS,
            FFN_DIM,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, N_LAYERS)

        # Side-feature embeddings (party: 2 ids, category: one per entry).
        self.p_emb = nn.Embedding(2, 32)
        self.c_emb = nn.Embedding(len(CATEGORIES), 64)

        # Fusion input = text (D_MODEL) + party (32) + category (64) + 2 scalars.
        fused_in = D_MODEL + 32 + 64 + 2
        self.fusion = nn.Sequential(nn.Linear(fused_in, D_MODEL), nn.GELU())

        self.msg_head = nn.Linear(D_MODEL, len(MSG_TYPES))
        self.px_head = nn.Sequential(
            nn.Linear(D_MODEL, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Softplus(),  # keeps the predicted price ratio positive
        )

    def forward(self, ids, mask, party, cat, ofn, tn):
        """Return ``(message-type logits, price prediction)`` for a batch."""
        tokens = self.pos(self.emb(ids))
        padding = mask == 0  # True marks padded positions to be ignored
        encoded = self.encoder(tokens, src_key_padding_mask=padding)
        first_token = encoded[:, 0]

        scalars = torch.stack([ofn, tn], dim=1)
        features = torch.cat(
            [first_token, self.p_emb(party), self.c_emb(cat), scalars],
            dim=1,
        )
        fused = self.fusion(features)

        logits = self.msg_head(fused)
        price = self.px_head(fused).squeeze(1)
        return logits, price
208
+
209
+ # ── Background Training Daemon ────────────────────────────────
210
def _training_thread_target(n_sessions: int, epochs: int, batch_size: int, lr: float):
    """Background-thread body: generate data, train, publish the model.

    Progress and errors are reported through ``STATE``. On success the
    trained model (in eval mode) is stored in ``GLOBAL_MODEL`` and
    ``STATE.model_ready`` is set. ``STATE.is_running`` is always cleared on
    exit, whether training succeeded or failed.

    Args:
        n_sessions: Number of synthetic sessions to generate.
        epochs: Training epochs; also the cosine schedule's ``T_max``.
        batch_size: DataLoader batch size.
        lr: Initial AdamW learning rate.
    """
    global GLOBAL_MODEL
    import traceback  # only needed on the failure path

    try:
        STATE.log(f"Starting data generation: {n_sessions:,} sessions (~{n_sessions*5:,} rows)")

        rows = generate_sessions(n_sessions)
        STATE.log("Data generated. Tokenizing into dataset...")

        dataset = NegotiationDataset(rows)
        # The dataset holds its own samples; drop the raw rows so they do
        # not stay in memory for the whole training run.
        del rows
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        STATE.log(f"Dataset ready: {len(dataset):,} samples. Initializing Model...")
        model = NegotiationTransformer().to(DEVICE)

        opt = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
        sch = CosineAnnealingLR(opt, T_max=epochs)
        ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()

        with STATE.lock:
            STATE.total_epochs = epochs
            STATE.losses = []

        STATE.log("Entering Training Loop (CPU mode).")
        total_batches = len(loader)
        # Refresh the progress string roughly ten times per epoch.
        progress_every = max(1, total_batches // 10)

        for ep in range(epochs):
            model.train()
            ep_loss = 0.0
            with STATE.lock:
                STATE.current_epoch = ep + 1

            for i, batch in enumerate(loader):
                if i % progress_every == 0:
                    with STATE.lock:
                        STATE.batch_progress = f"Epoch {ep+1}/{epochs} | Batch {i}/{total_batches}"

                batch = {key: value.to(DEVICE) for key, value in batch.items()}

                opt.zero_grad()
                mt_logits, px_pred = model(
                    batch["input_ids"], batch["attn_mask"],
                    batch["party"], batch["category"],
                    batch["ofn"], batch["tn"],
                )
                # Joint objective: action classification + half-weighted price MSE.
                loss = ce(mt_logits, batch["msg_type"]) + 0.5 * mse(px_pred, batch["price_t"])
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                ep_loss += loss.item()

            sch.step()
            avg_loss = ep_loss / max(total_batches, 1)

            with STATE.lock:
                STATE.losses.append(avg_loss)
            # STATE.log takes the (non-reentrant) lock itself, so it must be
            # called outside the ``with`` block above.
            STATE.log(f"Epoch {ep+1} complete. Loss: {avg_loss:.4f}")

        STATE.log("Training complete. Applying weights to Global Model.")
        model.eval()
        GLOBAL_MODEL = model
        with STATE.lock:
            STATE.model_ready = True

    except Exception as e:
        # Top-level boundary of the worker thread: nothing upstream can
        # handle the error, so record the type and full traceback.
        STATE.log(f"ERROR: {type(e).__name__}: {e}")
        STATE.log(traceback.format_exc())
    finally:
        with STATE.lock:
            STATE.is_running = False
277
+
278
def start_training(n_sessions, epochs, batch_size, lr):
    """Validate the UI inputs and launch the background training thread.

    Args:
        n_sessions: Number of synthetic sessions (numeric, >= 1).
        epochs: Number of epochs (numeric, >= 1).
        batch_size: Batch size (numeric, >= 1).
        lr: Learning rate (finite, > 0).

    Returns:
        A status string for the UI: a validation error, an "already
        running" notice, or confirmation that the thread was started.
    """
    # Convert and validate BEFORE setting is_running. If conversion raised
    # after the flag was set (e.g. an emptied Number field yields None), the
    # flag would never be cleared and training could not be started again.
    try:
        args = (int(n_sessions), int(epochs), int(batch_size), float(lr))
    except (TypeError, ValueError):
        return "Invalid training parameters: all fields must be numeric."

    n_val, ep_val, bs_val, lr_val = args
    if n_val < 1 or ep_val < 1 or bs_val < 1 or not (math.isfinite(lr_val) and lr_val > 0):
        return "Invalid training parameters: sessions, epochs, batch size and learning rate must be positive."

    with STATE.lock:
        if STATE.is_running:
            return "Training is already running!"
        STATE.is_running = True
        STATE.logs = []
        STATE.batch_progress = "Initializing..."

    worker = threading.Thread(target=_training_thread_target, args=args, daemon=True)
    try:
        worker.start()
    except RuntimeError:
        # The thread never ran, so its ``finally`` cannot clear the flag.
        with STATE.lock:
            STATE.is_running = False
        raise
    return "Background training thread triggered."
289
+
290
+ # ── Inference with Pre-built Templates ────────────────────────
291
+ def _get_template_message(msg_type: str, price: float, item: str, is_buyer: bool) -> str:
292
+ """The 'Mouth': Translates the Model's strategy (msg_type, price) into prose."""
293
+ px = f"${price:,.2f}"
294
+ if is_buyer:
295
+ templates = {
296
+ "offer": f"I'll start the bidding at {px} for the {item}.",
297
+ "counter": random.choice([f"I can offer {px}.", f"How about {px}?", f"My counter is {px}."]),
298
+ "accept": f"{px} works for me. I'll take it.",
299
+ "reject": "That's too high for my budget, I have to pass.",
300
+ "stall": "I need to check my budget and get back to you.",
301
+ "exit": "We're too far apart. Moving on."
302
+ }
303
+ else:
304
+ templates = {
305
+ "offer": f"I'm looking to get {px} for the {item}.",
306
+ "counter": random.choice([f"I can drop to {px}.", f"Best I can do right now is {px}.", f"Let's meet at {px}."]),
307
+ "accept": f"You got a deal at {px}.",
308
+ "reject": "I can't go that low.",
309
+ "stall": "Let me see if I have other offers first.",
310
+ "exit": "I can't sell it for that. Goodbye."
311
+ }
312
+ return templates.get(msg_type, f"Action: {msg_type} at {px}")
313
+
314
def predict(category, item, list_price, current_offer, history_text, party_str):
    """Run the trained model on one negotiation state and phrase its move.

    Args:
        category: Item category name; unknown names map to index 0.
        item: Item name, used only in the templated message.
        list_price: Listing price; must be greater than zero.
        current_offer: The offer currently on the table.
        history_text: Prior turns, one per line (e.g. ``"Seller: ..."``).
        party_str: ``"Buyer"`` to simulate the buyer; anything else is
            treated as the seller.

    Returns:
        A 4-tuple of strings: (predicted action, formatted price, templated
        message, per-action probabilities). If no model is loaded, or on
        failure, the first element carries a notice / ``"Error"`` and the
        third carries the error text on failure.
    """
    # Snapshot the reference so a concurrent retrain cannot swap the model
    # between the None check and the forward pass.
    model = GLOBAL_MODEL
    if model is None:
        return "Model not trained yet. Run training tab first.", "", "", ""

    try:
        lp, cp = float(list_price), float(current_offer)
        if lp <= 0:
            raise ValueError("List price must be greater than zero.")
        is_buyer = (party_str == "Buyer")
        pty = 1 if is_buyer else 0

        # Training joins turns with " [SEP] ", so rebuild that format from
        # the newline-separated UI text instead of tokenizing it raw.
        # ``history_text`` may be None or empty.
        history_lines = [ln.strip() for ln in (history_text or "").splitlines() if ln.strip()]
        text = " [SEP] ".join(history_lines) if history_lines else "(start)"
        enc = tokenizer(text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt")

        # Training encodes the turn number of the move being PREDICTED,
        # which is one past the number of turns already in the history.
        next_turn = len(history_lines) + 1

        p = torch.tensor([pty], dtype=torch.long, device=DEVICE)
        c = torch.tensor([CAT2IDX.get(category, 0)], dtype=torch.long, device=DEVICE)
        ofn = torch.tensor([min(cp / lp, 3.0)], dtype=torch.float, device=DEVICE)
        tn = torch.tensor([min(next_turn / 20.0, 1.0)], dtype=torch.float, device=DEVICE)

        with torch.no_grad():
            mt_logits, px = model(
                enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE),
                p, c, ofn, tn,
            )

        mt_idx = mt_logits.argmax(dim=1).item()
        msg_out = IDX2MSG[mt_idx]
        # The price head predicts a fraction of list price; scale to dollars.
        price_out = round(px.item() * lp, 2)

        prose_msg = _get_template_message(msg_out, price_out, item, is_buyer)
        probs = F.softmax(mt_logits, dim=1).squeeze(0).tolist()
        prob_str = " | ".join(f"{name}: {prob:.2f}" for name, prob in zip(MSG_TYPES, probs))

        return msg_out, f"${price_out:,.2f}", prose_msg, prob_str
    except Exception as e:
        # UI boundary: show the problem in the message box instead of raising.
        return "Error", "", str(e), ""
345
+
346
+ # ── Dashboard UI (Polling) ────────────────────────────────────
347
def refresh_dashboard():
    """Snapshot the training state for the polling dashboard.

    Returns:
        A 4-tuple ``(status text, joined log lines, loss-curve figure,
        model-readiness text)``.
    """
    # Copy everything under the lock, then format and draw outside it so
    # the training thread is never blocked on rendering.
    with STATE.lock:
        is_run = STATE.is_running
        progress = STATE.batch_progress
        log_text = "\n".join(STATE.logs)
        losses = list(STATE.losses)
        model_ready = STATE.model_ready

    status = "🟢 ACTIVE - " + progress if is_run else "🔴 IDLE"
    ready = "✅ Ready" if model_ready else "❌ Needs Training"

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        if losses:
            ax.plot(range(1, len(losses) + 1), losses, "b-o", markersize=4)
            ax.set_title("Training Loss")
        else:
            ax.text(0.5, 0.5, 'No data yet', ha='center', va='center', alpha=0.5)
        ax.grid(alpha=0.3)
        # Lay out this figure explicitly rather than pyplot's "current"
        # figure, which other callbacks may have changed.
        fig.tight_layout()
    finally:
        # This runs on every poll. pyplot keeps each figure registered until
        # it is closed, so unregister it here; the Figure object itself is
        # still renderable by the caller.
        plt.close(fig)

    return status, log_text, fig, ready
365
+
366
+ # ── Gradio ────────────────────────────────────────────────────
367
# Two-tab UI: training controls with a live dashboard, and an inference
# sandbox. A single timer polls refresh_dashboard() and feeds both tabs.
with gr.Blocks(title="ANP | HF Daemon Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ANP Background Trainer\nTrains on the HF free CPU via a background thread while you watch.")

    with gr.Tab("Dashboard & Training"):
        with gr.Row():
            n_sessions = gr.Number(value=40000, label="Sessions (~200k rows)")
            epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
            batch_size = gr.Slider(16, 256, value=64, step=16, label="Batch Size")
            lr = gr.Number(value=5e-4, label="Learning Rate")

        tr_btn = gr.Button("🚀 Start Background Training", variant="primary")

        gr.Markdown("### Real-Time Status *(Polls automatically)*")
        status_box = gr.Textbox(label="Thread Status", interactive=False)
        with gr.Row():
            log_box = gr.Textbox(label="System Logs", lines=12, interactive=False)
            plt_out = gr.Plot(label="Loss Curve")

    with gr.Tab("Inference Sandbox"):
        ready_indicator = gr.Textbox(label="Model Status", interactive=False)

        with gr.Row():
            d_cat = gr.Dropdown(CATEGORIES, value="used_car", label="Category")
            d_pty = gr.Radio(["Seller", "Buyer"], value="Buyer", label="Party to Simulate")
        with gr.Row():
            d_lp = gr.Number(value=18500, label="List Price ($)")
            d_co = gr.Number(value=16000, label="Current Offer ($)")

        d_item = gr.Textbox(value="2019 Honda Civic", label="Item Name (for template)")
        d_hist = gr.Textbox(lines=4, label="Turn History", placeholder="Seller: Asking $18,500.\nBuyer: I can do $15,000.")

        d_btn = gr.Button("Generate Move & Message", variant="primary")

        with gr.Row():
            d_msg = gr.Textbox(label="Action Head")
            d_px = gr.Textbox(label="Pricing Head")
        d_prose = gr.Textbox(label="Generated Message (Template)", lines=2)
        d_prob = gr.Textbox(label="Action Probabilities")

        d_btn.click(predict, inputs=[d_cat, d_item, d_lp, d_co, d_hist, d_pty], outputs=[d_msg, d_px, d_prose, d_prob])

    # One timer refreshes the dashboard every 3 seconds. The readiness
    # string that refresh_dashboard() already returns goes to the sandbox
    # indicator, so no hidden output or second polling timer is needed.
    gr.Timer(3, active=True).tick(
        fn=refresh_dashboard,
        outputs=[status_box, log_box, plt_out, ready_indicator],
    )

    tr_btn.click(start_training, inputs=[n_sessions, epochs, batch_size, lr], outputs=[status_box])

# Launch blocking the main thread, daemons will run in background
demo.launch(server_name="0.0.0.0", server_port=7860)