AbeBhatti commited on
Commit
6017516
·
1 Parent(s): 9e20ed6

Add all code, exclude large model weights

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +4 -0
  2. agent/arbitragent.py +29 -12
  3. agent/bluff_detector.py +92 -11
  4. demo/display.py +133 -156
  5. demo/run_demo.py +146 -177
  6. envs/arbitragent_env.py +190 -0
  7. training/arbitragent_colab.ipynb +171 -267
  8. training/bluff_training.log +16 -0
  9. training/checkpoints/bluff_classifier_tokenizer/tokenizer.json +0 -0
  10. training/checkpoints/bluff_classifier_tokenizer/tokenizer_config.json +14 -0
  11. training/checkpoints/phase2_final/README.md +67 -0
  12. training/checkpoints/phase2_final/chat_template.jinja +15 -0
  13. training/checkpoints/phase2_final/checkpoint-100/chat_template.jinja +15 -0
  14. training/checkpoints/phase2_final/checkpoint-100/config.json +32 -0
  15. training/checkpoints/phase2_final/checkpoint-100/generation_config.json +9 -0
  16. training/checkpoints/phase2_final/checkpoint-100/tokenizer.json +0 -0
  17. training/checkpoints/phase2_final/checkpoint-100/tokenizer_config.json +19 -0
  18. training/checkpoints/phase2_final/checkpoint-100/trainer_state.json +304 -0
  19. training/checkpoints/phase2_final/checkpoint-200/chat_template.jinja +15 -0
  20. training/checkpoints/phase2_final/checkpoint-200/config.json +32 -0
  21. training/checkpoints/phase2_final/checkpoint-200/generation_config.json +9 -0
  22. training/checkpoints/phase2_final/checkpoint-200/tokenizer.json +0 -0
  23. training/checkpoints/phase2_final/checkpoint-200/tokenizer_config.json +19 -0
  24. training/checkpoints/phase2_final/checkpoint-200/trainer_state.json +574 -0
  25. training/checkpoints/phase2_final/config.json +32 -0
  26. training/checkpoints/phase2_final/generation_config.json +9 -0
  27. training/checkpoints/phase2_final/tokenizer.json +0 -0
  28. training/checkpoints/phase2_final/tokenizer_config.json +19 -0
  29. training/checkpoints/unified_final/README.md +67 -0
  30. training/checkpoints/unified_final/chat_template.jinja +15 -0
  31. training/checkpoints/unified_final/checkpoint-100/chat_template.jinja +15 -0
  32. training/checkpoints/unified_final/checkpoint-100/config.json +32 -0
  33. training/checkpoints/unified_final/checkpoint-100/generation_config.json +9 -0
  34. training/checkpoints/unified_final/checkpoint-100/tokenizer.json +0 -0
  35. training/checkpoints/unified_final/checkpoint-100/tokenizer_config.json +19 -0
  36. training/checkpoints/unified_final/checkpoint-100/trainer_state.json +304 -0
  37. training/checkpoints/unified_final/checkpoint-200/chat_template.jinja +15 -0
  38. training/checkpoints/unified_final/checkpoint-200/config.json +32 -0
  39. training/checkpoints/unified_final/checkpoint-200/generation_config.json +9 -0
  40. training/checkpoints/unified_final/checkpoint-200/tokenizer.json +0 -0
  41. training/checkpoints/unified_final/checkpoint-200/tokenizer_config.json +19 -0
  42. training/checkpoints/unified_final/checkpoint-200/trainer_state.json +574 -0
  43. training/checkpoints/unified_final/config.json +32 -0
  44. training/checkpoints/unified_final/generation_config.json +9 -0
  45. training/checkpoints/unified_final/tokenizer.json +0 -0
  46. training/checkpoints/unified_final/tokenizer_config.json +19 -0
  47. training/checkpoints/unified_final/unified_reward_log.json +810 -0
  48. training/parse_poker.py +136 -0
  49. training/plot_phase2.py +24 -0
  50. training/train_bluff_classifier.py +142 -0
.gitignore CHANGED
@@ -13,3 +13,7 @@ selfplay_states_test.json
13
  proj_context.md
14
  session_progress.md
15
  HF_TOKEN
 
 
 
 
 
13
  proj_context.md
14
  session_progress.md
15
  HF_TOKEN
16
+ *.safetensors
17
+ *.bin
18
+ *.safetensors
19
+ *.bin
agent/arbitragent.py CHANGED
@@ -332,18 +332,22 @@ class ArbitrAgent:
332
  self.route_graph.mark_dead(edge.edge_id)
333
  continue
334
 
335
- # Bluff detection: inspect full thread via bluff_detector.
336
  signals = analyze_from_sim(c.sim, resp or "")
337
 
338
- # Log full bluff reasoning trace to structured log.
 
 
 
 
339
  self._structured_log.append(
340
  {
341
  "event": "bluff_analysis",
342
  "phase": 3,
 
343
  "seller_id": c.seller_id,
344
  "item": c.item,
345
- "turn": c.sim.turn,
346
- "seller_message": resp,
347
  "signals": {
348
  "timing_tell": signals.timing_tell,
349
  "size_tell": signals.size_tell,
@@ -352,10 +356,22 @@ class ArbitrAgent:
352
  "bluff_score": signals.bluff_score,
353
  "is_bluff": signals.is_bluff,
354
  },
355
- "thread_history": list(getattr(c.sim, "thread_history", [])),
 
356
  }
357
  )
358
 
 
 
 
 
 
 
 
 
 
 
 
359
  if verbose:
360
  print(
361
  f"[bluff_analysis {c.seller_id}] "
@@ -367,11 +383,10 @@ class ArbitrAgent:
367
  f"is_bluff={signals.is_bluff}"
368
  )
369
 
370
- # When a bluff is detected, immediately deploy coalition pressure.
371
  if signals.is_bluff:
372
  current_offer = float(c.sim.current_offer)
373
- # Simple heuristic counter: push meaningfully below stated offer.
374
- offer = max(1, int(current_offer - 8))
375
  pressure_msg = (
376
  "I have a trade offer from another seller that makes this less urgent for me — "
377
  f"can you do ${offer}?"
@@ -397,12 +412,16 @@ class ArbitrAgent:
397
  }
398
  )
399
 
400
- # Update entry cost after pressure-induced move.
401
  for edge in edges:
402
  self.route_graph.update_entry_cost(edge.edge_id, c.sim.current_offer)
 
 
 
 
 
 
403
 
404
  for edge in edges:
405
- # If we have a confirmed downstream target by this turn, upgrade probability.
406
  target_index = int(edge.trade_target_id.split("_")[1])
407
  if (edge.buy_item, target_index) in confirmed_targets:
408
  self.route_graph.update_confirmation_probability(
@@ -410,8 +429,6 @@ class ArbitrAgent:
410
  )
411
  self.route_graph.mark_confirmed(edge.edge_id)
412
 
413
- # Adjust seller reliability slightly based on bluff score.
414
- # Higher bluff score → more room to push → treat as slightly *higher* edge value.
415
  new_reliability = min(
416
  1.0, edge.seller_reliability + 0.1 * float(signals.bluff_score)
417
  )
 
332
  self.route_graph.mark_dead(edge.edge_id)
333
  continue
334
 
335
+ # Bluff detection: inspect full thread via BluffDetector.
336
  signals = analyze_from_sim(c.sim, resp or "")
337
 
338
+ # Unverified floor claim: formulaic language present but not flagged as full bluff.
339
+ formulaic_present = signals.formulaic_tell > 0
340
+
341
+ # Log full bluff reasoning: turn, seller_id, bluff_score, signals dict, action_taken.
342
+ action_taken = msg # the agent message we just sent before this response
343
  self._structured_log.append(
344
  {
345
  "event": "bluff_analysis",
346
  "phase": 3,
347
+ "turn": c.sim.turn,
348
  "seller_id": c.seller_id,
349
  "item": c.item,
350
+ "bluff_score": signals.bluff_score,
 
351
  "signals": {
352
  "timing_tell": signals.timing_tell,
353
  "size_tell": signals.size_tell,
 
356
  "bluff_score": signals.bluff_score,
357
  "is_bluff": signals.is_bluff,
358
  },
359
+ "action_taken": action_taken,
360
+ "seller_message": resp,
361
  }
362
  )
363
 
364
+ if not signals.is_bluff and formulaic_present:
365
+ self._structured_log.append(
366
+ {
367
+ "event": "unverified_floor_claim",
368
+ "phase": 3,
369
+ "turn": c.sim.turn,
370
+ "seller_id": c.seller_id,
371
+ "seller_message": resp,
372
+ }
373
+ )
374
+
375
  if verbose:
376
  print(
377
  f"[bluff_analysis {c.seller_id}] "
 
383
  f"is_bluff={signals.is_bluff}"
384
  )
385
 
386
+ # When a bluff is detected, deploy coalition pressure: floor - 4.
387
  if signals.is_bluff:
388
  current_offer = float(c.sim.current_offer)
389
+ offer = max(1, int(current_offer - 4))
 
390
  pressure_msg = (
391
  "I have a trade offer from another seller that makes this less urgent for me — "
392
  f"can you do ${offer}?"
 
412
  }
413
  )
414
 
 
415
  for edge in edges:
416
  self.route_graph.update_entry_cost(edge.edge_id, c.sim.current_offer)
417
+ # Bluff means seller has room — update confirmation probability upward.
418
+ for edge in edges:
419
+ self.route_graph.update_confirmation_probability(
420
+ edge.edge_id,
421
+ confirmation_probability=min(1.0, edge.confirmation_probability + 0.15),
422
+ )
423
 
424
  for edge in edges:
 
425
  target_index = int(edge.trade_target_id.split("_")[1])
426
  if (edge.buy_item, target_index) in confirmed_targets:
427
  self.route_graph.update_confirmation_probability(
 
429
  )
430
  self.route_graph.mark_confirmed(edge.edge_id)
431
 
 
 
432
  new_reliability = min(
433
  1.0, edge.seller_reliability + 0.1 * float(signals.bluff_score)
434
  )
agent/bluff_detector.py CHANGED
@@ -1,24 +1,100 @@
1
  """
2
  bluff_detector.py — bluff signal extraction for ArbitrAgent.
3
 
4
- This module exposes a small, deterministic API that inspects a seller's
5
- response in the context of a thread and extracts four bluff signals:
6
-
7
- 1. timing_tell
8
- 2. size_tell
9
- 3. formulaic_tell
10
- 4. pattern_tell
11
-
12
- The overall bluff_score is a weighted sum of these four signals. A response
13
- is flagged as a bluff when bluff_score > 0.6.
14
  """
15
 
16
  from __future__ import annotations
17
 
18
  import re
19
  from dataclasses import dataclass
 
20
  from typing import Any, Dict, List, Mapping, Optional, Sequence
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  FORMULAIC_PHRASES: List[str] = [
24
  "lowest i can go",
@@ -100,12 +176,17 @@ def analyze_bluff(
100
  for key in DEFAULT_WEIGHTS.keys()
101
  }
102
 
103
- bluff_score = (
104
  timing * norm_weights["timing_tell"]
105
  + size * norm_weights["size_tell"]
106
  + formulaic * norm_weights["formulaic_tell"]
107
  + pattern * norm_weights["pattern_tell"]
108
  )
 
 
 
 
 
109
  is_bluff = bluff_score > 0.6
110
 
111
  return BluffSignals(
 
1
  """
2
  bluff_detector.py — bluff signal extraction for ArbitrAgent.
3
 
4
+ Exposes four rule-based signals (timing, size, formulaic, pattern) and an optional
5
+ learned DistilBERT classifier trained on IRC poker bluff labels. Combined score:
6
+ bluff_score = 0.6 * learned_bluff_score + 0.4 * rule_score
7
+ is_bluff when bluff_score > 0.6.
 
 
 
 
 
 
8
  """
9
 
10
  from __future__ import annotations
11
 
12
  import re
13
  from dataclasses import dataclass
14
+ from pathlib import Path
15
  from typing import Any, Dict, List, Mapping, Optional, Sequence
16
 
17
+ # Lazy-loaded learned classifier (only on first use)
18
+ _bluff_classifier_model = None
19
+ _bluff_classifier_tokenizer = None
20
+
21
+
22
+ def _get_bluff_classifier():
23
+ """Lazy-load bluff_classifier.pt and tokenizer from training/checkpoints."""
24
+ global _bluff_classifier_model, _bluff_classifier_tokenizer
25
+ if _bluff_classifier_model is not None:
26
+ return _bluff_classifier_model, _bluff_classifier_tokenizer
27
+ pt_path = Path(__file__).resolve().parent.parent / "training" / "checkpoints" / "bluff_classifier.pt"
28
+ tok_dir = Path(__file__).resolve().parent.parent / "training" / "checkpoints" / "bluff_classifier_tokenizer"
29
+ if not pt_path.exists() or not tok_dir.exists():
30
+ return None, None
31
+ try:
32
+ import torch
33
+ from transformers import AutoTokenizer, AutoModel
34
+ _bluff_classifier_tokenizer = AutoTokenizer.from_pretrained(str(tok_dir))
35
+
36
+ class _BluffClassifierModule(torch.nn.Module):
37
+ def __init__(self):
38
+ super().__init__()
39
+ self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
40
+ self.head = torch.nn.Linear(self.encoder.config.hidden_size, 2)
41
+
42
+ def forward(self, input_ids, attention_mask=None, **kwargs):
43
+ out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
44
+ return self.head(out.last_hidden_state[:, 0, :])
45
+
46
+ _bluff_classifier_model = _BluffClassifierModule()
47
+ _bluff_classifier_model.load_state_dict(torch.load(pt_path, map_location="cpu", weights_only=True))
48
+ _bluff_classifier_model.eval()
49
+ return _bluff_classifier_model, _bluff_classifier_tokenizer
50
+ except Exception:
51
+ return None, None
52
+
53
+
54
+ def _thread_and_message_to_text(thread_history: Sequence[Mapping[str, Any]], seller_message: str) -> str:
55
+ """Convert thread + seller message into text matching poker training format (Position. Preflop. Flop. Turn. River. Pot)."""
56
+ parts: List[str] = []
57
+ for entry in thread_history:
58
+ if "agent" in entry:
59
+ parts.append(str(entry["agent"])[:80])
60
+ if "seller" in entry:
61
+ parts.append(str(entry["seller"])[:80])
62
+ # Map to poker-like: Preflop / Flop / Turn / River
63
+ preflop = parts[0] if len(parts) > 0 else "-"
64
+ flop = parts[1] if len(parts) > 1 else "-"
65
+ turn = parts[2] if len(parts) > 2 else "-"
66
+ river = seller_message[:200] if seller_message else "-"
67
+ return f"Position 1 of 2. Preflop: {preflop}. Flop: {flop}. Turn: {turn}. River: {river}. Pot: 0."
68
+
69
+
70
+ def learned_bluff_score(message: str, thread_history: Sequence[Mapping[str, Any]]) -> float:
71
+ """
72
+ Run learned DistilBERT classifier on (message + thread). Returns P(bluff) in [0, 1].
73
+ Returns 0.0 if classifier not loaded.
74
+ """
75
+ model, tokenizer = _get_bluff_classifier()
76
+ if model is None or tokenizer is None:
77
+ return 0.0
78
+ text = _thread_and_message_to_text(thread_history, message)
79
+ try:
80
+ import torch
81
+ enc = tokenizer(
82
+ text,
83
+ truncation=True,
84
+ max_length=128,
85
+ padding="max_length",
86
+ return_tensors="pt",
87
+ )
88
+ with torch.no_grad():
89
+ logits = model(
90
+ input_ids=enc["input_ids"],
91
+ attention_mask=enc["attention_mask"],
92
+ )
93
+ probs = torch.softmax(logits, dim=1)
94
+ return float(probs[0, 1].item()) # class 1 = bluff
95
+ except Exception:
96
+ return 0.0
97
+
98
 
99
  FORMULAIC_PHRASES: List[str] = [
100
  "lowest i can go",
 
176
  for key in DEFAULT_WEIGHTS.keys()
177
  }
178
 
179
+ rule_score = (
180
  timing * norm_weights["timing_tell"]
181
  + size * norm_weights["size_tell"]
182
  + formulaic * norm_weights["formulaic_tell"]
183
  + pattern * norm_weights["pattern_tell"]
184
  )
185
+ learned = learned_bluff_score(seller_message, thread_history)
186
+ if _bluff_classifier_model is not None:
187
+ bluff_score = 0.6 * learned + 0.4 * rule_score
188
+ else:
189
+ bluff_score = rule_score
190
  is_bluff = bluff_score > 0.6
191
 
192
  return BluffSignals(
demo/display.py CHANGED
@@ -1,9 +1,17 @@
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
4
  from typing import Any, Dict, List, Optional
5
 
6
- from rich.columns import Columns
7
  from rich.console import Console
8
  from rich.panel import Panel
9
  from rich.table import Table
@@ -13,7 +21,7 @@ from rich.text import Text
13
  @dataclass
14
  class ThreadMessage:
15
  turn: int
16
- sender: str # "agent" or "seller"
17
  text: str
18
  is_bluff: bool = False
19
 
@@ -23,20 +31,32 @@ class ThreadState:
23
  seller_id: str
24
  item: str
25
  archetype: str
26
- status: str = "active" # "active" | "dead" | "confirmed"
 
27
  messages: List[ThreadMessage] = field(default_factory=list)
28
  bluff_signals: Optional[Dict[str, float]] = None
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class NegotiationDisplay:
32
  """
33
- Rich-based terminal display for the ArbitrAgent demo.
34
-
35
- Responsibilities:
36
- - Show all active negotiation threads as side-by-side panels.
37
- - Highlight bluff detection in yellow with individual signals.
38
- - Use red for dead routes / threads and green for confirmed routes.
39
- - Render a final panel with budget → entry cost → exit value → return multiple.
40
  """
41
 
42
  def __init__(self, console: Optional[Console] = None) -> None:
@@ -47,172 +67,129 @@ class NegotiationDisplay:
47
  threads: List[ThreadState],
48
  route_summaries: List[Dict[str, Any]],
49
  budget: float,
 
50
  final_metrics: Optional[Dict[str, Any]] = None,
51
  checkpoints: Optional[Dict[str, bool]] = None,
52
  ) -> None:
53
- """Render the full demo view."""
54
  self.console.clear()
55
 
56
- thread_panels = [self._build_thread_panel(t) for t in threads]
57
- if thread_panels:
58
- self.console.print(Columns(thread_panels, expand=True, equal=True))
59
-
60
- # Routes + ROI panel at the bottom
61
- summary_panel = self._build_summary_panel(
62
- route_summaries=route_summaries,
63
- budget=budget,
64
- final_metrics=final_metrics,
65
- checkpoints=checkpoints or {},
66
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  self.console.print()
68
- self.console.print(summary_panel)
69
 
70
- # ------------------------------------------------------------------ #
71
- # Panel builders
72
- # ------------------------------------------------------------------ #
73
- def _build_thread_panel(self, thread: ThreadState) -> Panel:
74
- # Border colors by status
75
- border_style = "bright_white"
76
- if thread.status == "dead":
77
- border_style = "red"
78
- elif thread.status == "confirmed":
79
- border_style = "green"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- title = f"{thread.seller_id} {thread.item} • {thread.archetype}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  table = Table.grid(padding=(0, 1))
84
- table.expand = True
85
  table.add_column("Speaker", style="bold", no_wrap=True)
86
  table.add_column("Text", overflow="fold")
87
-
88
- # Only show the last few turns to keep panels readable.
89
- for msg in thread.messages[-8:]:
90
  speaker = "you" if msg.sender == "agent" else "seller"
91
  style = "cyan" if msg.sender == "agent" else "white"
92
  text = Text(msg.text, style=style)
93
  if msg.is_bluff:
94
- # Yellow highlight for bluff detection.
95
  text.stylize("black on yellow")
96
  table.add_row(speaker, text)
97
-
98
- # Bluff signal breakdown, if present.
99
  if thread.bluff_signals:
100
- sig = thread.bluff_signals
101
- sig_table = Table.grid(padding=(0, 1))
102
- sig_table.add_column(justify="left", no_wrap=True)
103
- sig_table.add_column(justify="right", no_wrap=True)
104
- sig_table.add_row(
105
- "[bold yellow]Bluff detected[/bold yellow]",
106
- f"[yellow]score={sig.get('bluff_score', 0.0):.2f}[/yellow]",
107
- )
108
- for key in ("timing_tell", "size_tell", "formulaic_tell", "pattern_tell"):
109
- if key in sig:
110
- label = key.replace("_", " ")
111
- sig_table.add_row(label, f"{sig[key]:.2f}")
112
-
113
- table.add_row("", sig_table)
114
-
115
- return Panel(
116
- table,
117
- title=title,
118
- border_style=border_style,
119
- padding=(1, 1),
120
- )
121
-
122
- def _build_summary_panel(
123
- self,
124
- route_summaries: List[Dict[str, Any]],
125
- budget: float,
126
- final_metrics: Optional[Dict[str, Any]],
127
- checkpoints: Dict[str, bool],
128
- ) -> Panel:
129
- table = Table.grid(padding=(0, 2))
130
- table.expand = True
131
-
132
- # Left: route statuses
133
- routes_sub = Table(
134
- show_header=True,
135
- header_style="bold",
136
- title="Route Graph",
137
- title_style="bold",
138
- )
139
- routes_sub.add_column("Route", no_wrap=True)
140
- routes_sub.add_column("Status", no_wrap=True)
141
- routes_sub.add_column("Δ", justify="right", no_wrap=True)
142
- routes_sub.add_column("Score", justify="right", no_wrap=True)
143
-
144
- for row in route_summaries:
145
- margin = row["exit_value"] - row["entry_cost"]
146
- status = row["status"]
147
- status_style = {
148
- "dead": "red",
149
- "confirmed": "green",
150
- "soft": "yellow",
151
- }.get(status, "white")
152
- routes_sub.add_row(
153
- row["edge_id"],
154
- f"[{status_style}]{status}[/{status_style}]",
155
- f"{margin:.2f}",
156
- f"{row['score']:.2f}",
157
- )
158
-
159
- # Right: ROI + checkpoints
160
- roi_sub = Table(
161
- show_header=False,
162
- box=None,
163
- title="Capital Deployment",
164
- title_style="bold",
165
- )
166
- roi_sub.add_column("Label", no_wrap=True)
167
- roi_sub.add_column("Value", no_wrap=True)
168
-
169
- entry_cost = None
170
- exit_value = None
171
- return_multiple = None
172
-
173
- if final_metrics is not None:
174
- entry_cost = final_metrics.get("entry_cost")
175
- exit_value = final_metrics.get("exit_value")
176
- return_multiple = final_metrics.get("return_multiple")
177
-
178
- roi_sub.add_row("Budget", f"$ {budget:.2f}")
179
- if entry_cost is not None:
180
- roi_sub.add_row("Entry cost", f"$ {entry_cost:.2f}")
181
- if exit_value is not None:
182
- roi_sub.add_row("Exit value", f"$ {exit_value:.2f}")
183
- if return_multiple is not None:
184
- roi_sub.add_row("Return multiple", f"{return_multiple:.2f}x")
185
-
186
- # Checkpoints list
187
- checkpoints_sub = Table(
188
- show_header=False,
189
- box=None,
190
- title="Demo Checkpoints",
191
- title_style="bold",
192
- )
193
- checkpoints_sub.add_column("State", no_wrap=True)
194
-
195
- labels = [
196
- ("multi_thread_view", "Threads visible"),
197
- ("bluff_detected", "Bluff flagged"),
198
- ("dead_route_seen", "Dead route surfaced"),
199
- ("route_confirmed", "Route confirmed"),
200
- ("execution_complete", "Executed & logged"),
201
- ]
202
- for key, label in labels:
203
- done = checkpoints.get(key, False)
204
- style = "green" if done else "dim"
205
- marker = "●" if done else "○"
206
- checkpoints_sub.add_row(f"[{style}]{marker} {label}[/{style}]")
207
-
208
- table.add_row(routes_sub, roi_sub, checkpoints_sub)
209
- return Panel(
210
- table,
211
- title="ArbitrAgent — $20 → Multi-Route Arbitrage",
212
- border_style="cyan",
213
- padding=(1, 1),
214
- )
215
 
216
 
217
  __all__ = ["NegotiationDisplay", "ThreadState", "ThreadMessage"]
218
-
 
1
+ """
2
+ Rich terminal UI for the ArbitrAgent demo.
3
+
4
+ Panel 1: NEGOTIATION THREADS — one row per seller (name, item, current offer, status).
5
+ Panel 2: LIVE EVENT LOG — scrolling [BLUFF DETECTED], [GOOD OUTCOME], [HUMAN-ALIGNED MOVE], [ROUTE KILLED].
6
+ Panel 3: ROUTE GRAPH — route_id, entry, exit, score, status.
7
+ Panel 4: FINAL RESULT — Budget → Deployed → Final Value → Return, route and why.
8
+ """
9
+
10
  from __future__ import annotations
11
 
12
  from dataclasses import dataclass, field
13
  from typing import Any, Dict, List, Optional
14
 
 
15
  from rich.console import Console
16
  from rich.panel import Panel
17
  from rich.table import Table
 
21
  @dataclass
22
  class ThreadMessage:
23
  turn: int
24
+ sender: str
25
  text: str
26
  is_bluff: bool = False
27
 
 
31
  seller_id: str
32
  item: str
33
  archetype: str
34
+ status: str = "active" # "active" | "pending" | "confirmed" | "dead"
35
+ current_offer: Optional[float] = None
36
  messages: List[ThreadMessage] = field(default_factory=list)
37
  bluff_signals: Optional[Dict[str, float]] = None
38
 
39
 
40
+ # Event types for the live event log
41
+ BluffDetectedEvent = Dict[str, Any] # seller_name, turn, timing_tell, size_tell, formulaic_tell, pattern_tell, action_taken
42
+ GoodOutcomeEvent = Dict[str, Any] # route_id, entry_cost, exit_value, return_multiple, did_not_accept_floor
43
+ HumanAlignedEvent = Dict[str, Any] # phase_name, action_taken, similarity_pct
44
+ RouteKilledEvent = Dict[str, Any] # seller_name, reason, capital_preserved
45
+
46
+
47
+ def _status_style(status: str) -> str:
48
+ if status == "confirmed":
49
+ return "green"
50
+ if status == "active":
51
+ return "yellow"
52
+ if status == "dead":
53
+ return "red"
54
+ return "white" # pending
55
+
56
+
57
  class NegotiationDisplay:
58
  """
59
+ Live terminal UI: negotiation threads, event log, route graph, final result.
 
 
 
 
 
 
60
  """
61
 
62
  def __init__(self, console: Optional[Console] = None) -> None:
 
67
  threads: List[ThreadState],
68
  route_summaries: List[Dict[str, Any]],
69
  budget: float,
70
+ event_log: Optional[List[Dict[str, Any]]] = None,
71
  final_metrics: Optional[Dict[str, Any]] = None,
72
  checkpoints: Optional[Dict[str, bool]] = None,
73
  ) -> None:
 
74
  self.console.clear()
75
 
76
+ # Panel 1 NEGOTIATION THREADS
77
+ threads_table = Table(
78
+ show_header=True,
79
+ header_style="bold",
80
+ title="NEGOTIATION THREADS",
81
+ title_style="bold",
 
 
 
 
82
  )
83
+ threads_table.add_column("Seller", no_wrap=True)
84
+ threads_table.add_column("Item", no_wrap=True)
85
+ threads_table.add_column("Current offer", justify="right", no_wrap=True)
86
+ threads_table.add_column("Status", no_wrap=True)
87
+ for t in threads:
88
+ offer_str = f"${t.current_offer:.2f}" if t.current_offer is not None else "—"
89
+ style = _status_style(t.status)
90
+ threads_table.add_row(
91
+ t.seller_id,
92
+ t.item,
93
+ offer_str,
94
+ f"[{style}]{t.status}[/{style}]",
95
+ )
96
+ self.console.print(Panel(threads_table, border_style="cyan", padding=(0, 1)))
97
  self.console.print()
 
98
 
99
+ # Panel 2 — LIVE EVENT LOG (scrolling, last N events)
100
+ events = event_log or []
101
+ log_lines: List[Any] = []
102
+ for ev in events[-30:]:
103
+ kind = ev.get("type") or ev.get("event")
104
+ if kind == "bluff_detected":
105
+ log_lines.append(Text("[BLUFF DETECTED]", style="bold yellow"))
106
+ log_lines.append(Text(f" {ev.get('seller_name', ev.get('seller_id', ''))}, turn {ev.get('turn', '')}"))
107
+ log_lines.append(Text(f" ✦ timing tell: {ev.get('timing_tell', 0):.2f}"))
108
+ log_lines.append(Text(f" ✦ size tell: {ev.get('size_tell', 0):.2f}"))
109
+ log_lines.append(Text(f" ✦ formulaic tell: {ev.get('formulaic_tell', 0):.2f}"))
110
+ log_lines.append(Text(f" ✦ pattern tell: {ev.get('pattern_tell', 0):.2f}"))
111
+ log_lines.append(Text(f" → action taken: {ev.get('action_taken', '')[:80]}..."))
112
+ log_lines.append(Text(""))
113
+ elif kind == "good_outcome":
114
+ log_lines.append(Text("[GOOD OUTCOME]", style="bold green"))
115
+ log_lines.append(Text(f" route {ev.get('route_id', '')}, entry ${ev.get('entry_cost', 0):.2f}, exit ${ev.get('exit_value', 0):.2f}, return {ev.get('return_multiple', 0):.2f}x"))
116
+ log_lines.append(Text(" ✦ did not accept stated floor"))
117
+ log_lines.append(Text(""))
118
+ elif kind == "human_aligned":
119
+ log_lines.append(Text("[HUMAN-ALIGNED MOVE]", style="bold blue"))
120
+ log_lines.append(Text(f" {ev.get('phase_name', '')}: {str(ev.get('action_taken', ''))[:60]}..."))
121
+ log_lines.append(Text(f" ✦ matches human Diplomacy pattern: {ev.get('similarity_pct', 0):.0f}% similarity"))
122
+ log_lines.append(Text(""))
123
+ elif kind == "route_killed":
124
+ log_lines.append(Text("[ROUTE KILLED]", style="bold red"))
125
+ log_lines.append(Text(f" {ev.get('seller_name', ev.get('seller_id', ''))}, {ev.get('reason', '')}"))
126
+ log_lines.append(Text(" ✦ capital preserved, pivoting"))
127
+ log_lines.append(Text(""))
128
+
129
+ if log_lines:
130
+ log_content = Text()
131
+ for line in log_lines:
132
+ log_content.append_text(line)
133
+ log_content.append("\n")
134
+ self.console.print(Panel(log_content, title="LIVE EVENT LOG", border_style="dim", padding=(0, 1), height=14))
135
+ else:
136
+ self.console.print(Panel("(no events yet)", title="LIVE EVENT LOG", border_style="dim", padding=(0, 1), height=6))
137
+ self.console.print()
138
 
139
+ # Panel 3 ROUTE GRAPH
140
+ route_table = Table(
141
+ show_header=True,
142
+ header_style="bold",
143
+ title="ROUTE GRAPH",
144
+ title_style="bold",
145
+ )
146
+ route_table.add_column("route_id", no_wrap=True)
147
+ route_table.add_column("entry", justify="right", no_wrap=True)
148
+ route_table.add_column("exit", justify="right", no_wrap=True)
149
+ route_table.add_column("score", justify="right", no_wrap=True)
150
+ route_table.add_column("status", no_wrap=True)
151
+ for row in route_summaries:
152
+ st = row.get("status", "soft")
153
+ route_table.add_row(
154
+ row.get("edge_id", ""),
155
+ f"${row.get('entry_cost', 0):.2f}",
156
+ f"${row.get('exit_value', 0):.2f}",
157
+ f"{row.get('score', 0):.2f}",
158
+ f"[{_status_style(st)}]{st}[/{_status_style(st)}]",
159
+ )
160
+ self.console.print(Panel(route_table, border_style="cyan", padding=(0, 1)))
161
+ self.console.print()
162
 
163
+ # Panel 4 — FINAL RESULT (when available)
164
+ if final_metrics is not None:
165
+ entry = final_metrics.get("entry_cost")
166
+ exit_val = final_metrics.get("exit_value")
167
+ ret = final_metrics.get("return_multiple")
168
+ route_id = final_metrics.get("route_id", "")
169
+ why = final_metrics.get("why", "best scored confirmed route")
170
+ line1 = f"Budget: ${budget:.1f} → Deployed: ${entry:.2f} → Final Value: ${exit_val:.2f} → Return: {ret:.2f}x"
171
+ line2 = f"Route: {route_id} — {why}"
172
+ self.console.print(Panel(f"[bold]{line1}[/bold]\n\n{line2}", title="FINAL RESULT", border_style="green", padding=(1, 2)))
173
+ elif checkpoints and checkpoints.get("execution_complete"):
174
+ self.console.print(Panel("No route executed. Capital preserved.", title="FINAL RESULT", border_style="yellow", padding=(1, 2)))
175
+
176
+ # Legacy API: build thread panel per thread (for side-by-side thread view if needed)
177
+ def _build_thread_panel(self, thread: ThreadState) -> Panel:
178
+ border_style = _status_style(thread.status)
179
+ title = f"{thread.seller_id} • {thread.item}"
180
  table = Table.grid(padding=(0, 1))
 
181
  table.add_column("Speaker", style="bold", no_wrap=True)
182
  table.add_column("Text", overflow="fold")
183
+ for msg in thread.messages[-6:]:
 
 
184
  speaker = "you" if msg.sender == "agent" else "seller"
185
  style = "cyan" if msg.sender == "agent" else "white"
186
  text = Text(msg.text, style=style)
187
  if msg.is_bluff:
 
188
  text.stylize("black on yellow")
189
  table.add_row(speaker, text)
 
 
190
  if thread.bluff_signals:
191
+ table.add_row("", f"[yellow]bluff_score={thread.bluff_signals.get('bluff_score', 0):.2f}[/yellow]")
192
+ return Panel(table, title=title, border_style=border_style, padding=(0, 1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  __all__ = ["NegotiationDisplay", "ThreadState", "ThreadMessage"]
 
demo/run_demo.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import argparse
@@ -8,37 +14,49 @@ import time
8
  from dataclasses import asdict
9
  from typing import Any, Dict, List
10
 
11
- # Ensure project root is on sys.path when run as `python demo/run_demo.py`.
12
  PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
  if PROJECT_ROOT not in sys.path:
14
  sys.path.insert(0, PROJECT_ROOT)
15
 
16
- from agent.arbitragent import ArbitrAgent, SellerCandidate # type: ignore
17
  from agent.bluff_detector import analyze_from_sim
18
  from agent.route_graph import RouteEdge
19
  from demo.display import NegotiationDisplay, ThreadMessage, ThreadState
20
  from simulation.scenario import get_scenario
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  class DemoArbitrAgent(ArbitrAgent):
24
  """
25
- Thin wrapper around ArbitrAgent that:
26
- - Drives the existing five-phase loop.
27
- - Streams state into the Rich-based NegotiationDisplay.
28
- - Builds a structured JSON log of the entire episode.
29
  """
30
 
 
 
 
 
31
  def run_with_display(
32
  self,
33
  budget: float,
34
- sleep_per_tick: float = 0.7,
 
35
  ) -> Dict[str, Any]:
36
  self.budget = float(budget)
37
-
38
  sellers, trade_targets = get_scenario()
39
  display = NegotiationDisplay()
 
40
 
41
- # Checkpoint flags for the demo.
42
  checkpoints: Dict[str, bool] = {
43
  "multi_thread_view": False,
44
  "bluff_detected": False,
@@ -47,111 +65,79 @@ class DemoArbitrAgent(ArbitrAgent):
47
  "execution_complete": False,
48
  }
49
 
50
- # Thread state tracking per seller.
51
  threads: Dict[str, ThreadState] = {}
52
 
53
- def get_thread_for_candidate(cand: SellerCandidate) -> ThreadState:
54
  if cand.seller_id not in threads:
55
  threads[cand.seller_id] = ThreadState(
56
  seller_id=cand.seller_id,
57
  item=cand.item,
58
  archetype=cand.archetype,
 
59
  )
60
  return threads[cand.seller_id]
61
 
62
- # Structured log scaffold.
 
 
 
 
63
  log: Dict[str, Any] = {
64
  "budget": self.budget,
 
 
65
  "events": [],
66
  "routes": [],
67
  "final": {},
68
  "checkpoints": checkpoints,
69
  }
70
-
71
  start_time = time.time()
72
 
73
- # -----------------------------
74
- # Phase 1: Scout + soft inquiry
75
- # -----------------------------
76
  candidates = self._phase1_scout(sellers)
77
-
78
  for cand in candidates:
79
  log["events"].append(
80
- {
81
- "phase": 1,
82
- "type": "candidate_scored",
83
- "seller_id": cand.seller_id,
84
- "item": cand.item,
85
- "score": cand.score,
86
- "listing_price": cand.listing_price,
87
- "resale_value": cand.resale_value,
88
- }
89
  )
90
 
91
- # Open soft inquiries and populate initial threads.
92
  for cand in candidates:
93
- thread = get_thread_for_candidate(cand)
94
  msg = f"hey, is the {cand.item} still available? any room on price?"
95
  resp = cand.sim.step(msg)
96
-
97
- thread.messages.append(
98
- ThreadMessage(turn=cand.sim.turn, sender="agent", text=msg)
99
- )
100
  if resp is not None:
101
- thread.messages.append(
102
- ThreadMessage(turn=cand.sim.turn, sender="seller", text=resp)
103
- )
104
 
105
- log["events"].append(
106
- {
107
- "phase": 1,
108
- "type": "soft_inquiry",
109
- "seller_id": cand.seller_id,
110
- "agent_message": msg,
111
- "seller_response": resp,
112
- }
113
- )
114
-
115
- # Initial multi-thread view.
116
  checkpoints["multi_thread_view"] = True
 
117
  display.render(
118
  threads=list(threads.values()),
119
  route_summaries=self.route_graph.summary(),
120
  budget=self.budget,
 
121
  final_metrics=None,
122
  checkpoints=checkpoints,
123
  )
124
  time.sleep(sleep_per_tick)
125
 
126
- # -----------------------------
127
- # Phase 2: Route Mapping
128
- # -----------------------------
129
- seller_to_edges = self._phase2_build_routes(
130
- candidates=candidates,
131
- trade_targets=trade_targets,
132
- verbose=False,
133
- )
134
-
135
- # Render after routes created (still soft).
136
  display.render(
137
  threads=list(threads.values()),
138
  route_summaries=self.route_graph.summary(),
139
  budget=self.budget,
 
140
  final_metrics=None,
141
  checkpoints=checkpoints,
142
  )
143
  time.sleep(sleep_per_tick)
144
 
145
- # -----------------------------
146
- # Phase 3: Pressure & Confirm
147
- # -----------------------------
148
- if candidates:
149
- max_turn = max(t["confirmed_at_turn"] for t in trade_targets)
150
- else:
151
- max_turn = 0
152
-
153
  for turn in range(2, max_turn + 1):
154
- # Which downstream trade targets are confirmed by this turn?
155
  confirmed_targets = {
156
  (t["item"], idx)
157
  for idx, t in enumerate(trade_targets)
@@ -159,66 +145,46 @@ class DemoArbitrAgent(ArbitrAgent):
159
  }
160
 
161
  for cand in candidates:
162
- edges_for_seller: List[RouteEdge] = seller_to_edges.get(
163
- cand.seller_id, []
164
- )
165
 
166
- # Track threads even if seller has no explicit route edges (e.g., ghoster).
167
- thread = get_thread_for_candidate(cand)
168
-
169
- # Death / ghosting.
170
  if cand.sim.is_dead():
171
  if thread.status != "dead":
172
  thread.status = "dead"
173
  checkpoints["dead_route_seen"] = True
174
- log["events"].append(
175
- {
176
- "phase": 3,
177
- "turn": turn,
178
- "type": "route_dead",
179
- "seller_id": cand.seller_id,
180
- }
181
- )
182
- # If there are edges, mark them dead in the graph.
183
  for edge in edges_for_seller:
184
  self.route_graph.mark_dead(edge.edge_id)
185
  continue
186
 
187
- # Do we have a confirmed downstream target by this turn?
188
  has_confirmed_downstream = any(
189
- (edge.buy_item, int(edge.trade_target_id.split("_")[1]))
190
- in confirmed_targets
191
- for edge in edges_for_seller
192
  )
193
-
194
  if has_confirmed_downstream:
195
  agent_msg = (
196
  f"i have another buyer interested in the {cand.item}, "
197
- "but i'd prefer to buy from you if we can make the numbers work. "
198
- "could you do a bit better on price?"
199
  )
200
  else:
201
- agent_msg = (
202
- f"just checking back on the {cand.item} — any flexibility on your price at all?"
203
- )
204
 
205
  resp = cand.sim.step(agent_msg)
206
-
207
- # Log messages into thread.
208
- thread.messages.append(
209
- ThreadMessage(turn=cand.sim.turn, sender="agent", text=agent_msg)
210
- )
211
  if resp is not None:
212
- thread.messages.append(
213
- ThreadMessage(turn=cand.sim.turn, sender="seller", text=resp)
214
- )
215
 
216
- # Bluff analysis if we got a response.
217
  if resp is not None:
218
  signals = analyze_from_sim(cand.sim, resp)
219
  if signals.is_bluff:
220
  checkpoints["bluff_detected"] = True
221
- # Mark the most recent seller message as bluff-highlighted.
222
  thread.messages[-1].is_bluff = True
223
  thread.bluff_signals = {
224
  "timing_tell": signals.timing_tell,
@@ -227,70 +193,81 @@ class DemoArbitrAgent(ArbitrAgent):
227
  "pattern_tell": signals.pattern_tell,
228
  "bluff_score": signals.bluff_score,
229
  }
230
- log["events"].append(
231
- {
232
- "phase": 3,
233
- "turn": turn,
234
- "type": "bluff_detected",
235
- "seller_id": cand.seller_id,
236
- "message": resp,
237
- "signals": asdict(signals),
238
- }
 
 
 
 
 
 
 
 
 
 
239
  )
240
-
241
- log["events"].append(
242
- {
243
- "phase": 3,
244
- "turn": turn,
245
- "type": "negotiation_turn",
246
- "seller_id": cand.seller_id,
247
- "agent_message": agent_msg,
248
- "seller_response": resp,
249
- }
250
- )
251
-
252
- # Update entry cost with latest offer.
 
 
 
 
253
  for edge in edges_for_seller:
254
  self.route_graph.update_entry_cost(edge.edge_id, cand.sim.current_offer)
255
 
256
- # If seller ghosted after this message, mark dead.
257
  if cand.sim.is_dead():
258
  if thread.status != "dead":
259
  thread.status = "dead"
260
  checkpoints["dead_route_seen"] = True
 
 
 
 
 
 
261
  for edge in edges_for_seller:
262
  self.route_graph.mark_dead(edge.edge_id)
263
  continue
264
 
265
- # Upgrade confirmation probability when downstream target has confirmed.
266
  for edge in edges_for_seller:
267
  target_index = int(edge.trade_target_id.split("_")[1])
268
  if (edge.buy_item, target_index) in confirmed_targets:
269
- self.route_graph.update_confirmation_probability(
270
- edge.edge_id, confirmation_probability=0.9
271
- )
272
  self.route_graph.mark_confirmed(edge.edge_id)
273
  thread.status = "confirmed"
274
  checkpoints["route_confirmed"] = True
275
 
276
- # Render this turn.
277
  display.render(
278
  threads=list(threads.values()),
279
  route_summaries=self.route_graph.summary(),
280
  budget=self.budget,
 
281
  final_metrics=None,
282
  checkpoints=checkpoints,
283
  )
284
  time.sleep(sleep_per_tick)
285
 
286
- # -----------------------------
287
- # Phase 4: Route Scoring
288
- # -----------------------------
289
  self.route_graph.prune_below_threshold()
290
-
291
- # -----------------------------
292
- # Phase 5: Execute
293
- # -----------------------------
294
  best = self.route_graph.best_route()
295
  route_summary = self.route_graph.summary()
296
  log["routes"] = route_summary
@@ -303,12 +280,11 @@ class DemoArbitrAgent(ArbitrAgent):
303
  "return_multiple": 1.0,
304
  "duration_seconds": time.time() - start_time,
305
  }
 
306
  else:
307
  profit = best.exit_value - best.entry_cost
308
  final_value = self.budget - best.entry_cost + best.exit_value
309
- route_multiple = (
310
- best.exit_value / best.entry_cost if best.entry_cost > 0 else 0.0
311
- )
312
  final = {
313
  "best_route": {
314
  "edge_id": best.edge_id,
@@ -322,64 +298,57 @@ class DemoArbitrAgent(ArbitrAgent):
322
  "return_multiple": route_multiple,
323
  "duration_seconds": time.time() - start_time,
324
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  checkpoints["execution_complete"] = True
327
  log["final"] = final
328
 
329
- # Final render with ROI panel filled.
330
- final_route = None
331
- if final["best_route"] is not None:
332
- final_route = final["best_route"]
333
  display.render(
334
  threads=list(threads.values()),
335
  route_summaries=route_summary,
336
  budget=self.budget,
337
- final_metrics={
338
- "entry_cost": final_route["entry_cost"] if final_route else None,
339
- "exit_value": final_route["exit_value"] if final_route else None,
340
- "return_multiple": final["return_multiple"],
341
- },
342
  checkpoints=checkpoints,
343
  )
344
-
345
  return log
346
 
347
 
348
  def main() -> None:
349
- parser = argparse.ArgumentParser(
350
- description="Run the ArbitrAgent Rich demo (90-second negotiation walkthrough)."
351
- )
352
- parser.add_argument(
353
- "--budget",
354
- type=float,
355
- default=20.0,
356
- help="Starting cash budget for the agent (default: 20.0).",
357
- )
358
- parser.add_argument(
359
- "--sleep",
360
- type=float,
361
- default=15.0,
362
- help="Seconds to pause between display updates (default: 15.0, ~90s total demo).",
363
- )
364
- parser.add_argument(
365
- "--log-path",
366
- type=str,
367
- default=None,
368
- help="Optional path to write the structured JSON log. If omitted, prints to stdout only.",
369
- )
370
  args = parser.parse_args()
371
 
372
- agent = DemoArbitrAgent(budget=args.budget, min_route_score=1.0)
373
- log = agent.run_with_display(budget=args.budget, sleep_per_tick=args.sleep)
 
 
374
 
375
  json_str = json.dumps(log, indent=2, default=float)
376
- if args.log_path:
377
- with open(args.log_path, "w") as f:
378
- f.write(json_str)
379
  print("\n=== Structured Demo Log (JSON) ===")
380
- print(json_str)
 
381
 
382
 
383
  if __name__ == "__main__":
384
  main()
385
-
 
1
+ """
2
+ Demo entry point: budget, scenario, full 5-phase agent loop with Rich display.
3
+ Loads unified_final checkpoint if present, else phase2_final. Saves log to demo/sample_run_log.json.
4
+ Must complete in under 90 seconds.
5
+ """
6
+
7
  from __future__ import annotations
8
 
9
  import argparse
 
14
  from dataclasses import asdict
15
  from typing import Any, Dict, List
16
 
 
17
  PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18
  if PROJECT_ROOT not in sys.path:
19
  sys.path.insert(0, PROJECT_ROOT)
20
 
21
+ from agent.arbitragent import ArbitrAgent, SellerCandidate
22
  from agent.bluff_detector import analyze_from_sim
23
  from agent.route_graph import RouteEdge
24
  from demo.display import NegotiationDisplay, ThreadMessage, ThreadState
25
  from simulation.scenario import get_scenario
26
 
27
 
28
+ def _resolve_checkpoint_path() -> str | None:
29
+ """Unified_final if exists, else phase2_final."""
30
+ unified = os.path.join(PROJECT_ROOT, "training", "checkpoints", "unified_final")
31
+ phase2 = os.path.join(PROJECT_ROOT, "training", "checkpoints", "phase2_final")
32
+ if os.path.isdir(unified):
33
+ return unified
34
+ if os.path.isdir(phase2):
35
+ return phase2
36
+ return None
37
+
38
+
39
  class DemoArbitrAgent(ArbitrAgent):
40
  """
41
+ Runs full 5-phase loop with display and event log.
42
+ Uses checkpoint path for future model loading; currently heuristic agent.
 
 
43
  """
44
 
45
+ def __init__(self, budget: float = 20.0, min_route_score: float = 1.0, checkpoint_path: str | None = None):
46
+ super().__init__(budget=budget, min_route_score=min_route_score)
47
+ self.checkpoint_path = checkpoint_path or _resolve_checkpoint_path()
48
+
49
  def run_with_display(
50
  self,
51
  budget: float,
52
+ scenario: str = "standard_demo",
53
+ sleep_per_tick: float = 0.5,
54
  ) -> Dict[str, Any]:
55
  self.budget = float(budget)
 
56
  sellers, trade_targets = get_scenario()
57
  display = NegotiationDisplay()
58
+ event_log: List[Dict[str, Any]] = []
59
 
 
60
  checkpoints: Dict[str, bool] = {
61
  "multi_thread_view": False,
62
  "bluff_detected": False,
 
65
  "execution_complete": False,
66
  }
67
 
 
68
  threads: Dict[str, ThreadState] = {}
69
 
70
+ def get_thread(cand: SellerCandidate) -> ThreadState:
71
  if cand.seller_id not in threads:
72
  threads[cand.seller_id] = ThreadState(
73
  seller_id=cand.seller_id,
74
  item=cand.item,
75
  archetype=cand.archetype,
76
+ current_offer=cand.sim.current_offer,
77
  )
78
  return threads[cand.seller_id]
79
 
80
+ def sync_offers():
81
+ for c in candidates:
82
+ t = get_thread(c)
83
+ t.current_offer = c.sim.current_offer
84
+
85
  log: Dict[str, Any] = {
86
  "budget": self.budget,
87
+ "scenario": scenario,
88
+ "checkpoint_path": self.checkpoint_path,
89
  "events": [],
90
  "routes": [],
91
  "final": {},
92
  "checkpoints": checkpoints,
93
  }
 
94
  start_time = time.time()
95
 
96
+ # Phase 1
 
 
97
  candidates = self._phase1_scout(sellers)
 
98
  for cand in candidates:
99
  log["events"].append(
100
+ {"phase": 1, "type": "candidate_scored", "seller_id": cand.seller_id, "item": cand.item, "score": cand.score}
 
 
 
 
 
 
 
 
101
  )
102
 
 
103
  for cand in candidates:
104
+ thread = get_thread(cand)
105
  msg = f"hey, is the {cand.item} still available? any room on price?"
106
  resp = cand.sim.step(msg)
107
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="agent", text=msg))
 
 
 
108
  if resp is not None:
109
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="seller", text=resp))
110
+ thread.current_offer = cand.sim.current_offer
111
+ log["events"].append({"phase": 1, "type": "soft_inquiry", "seller_id": cand.seller_id, "agent_message": msg, "seller_response": resp})
112
 
 
 
 
 
 
 
 
 
 
 
 
113
  checkpoints["multi_thread_view"] = True
114
+ sync_offers()
115
  display.render(
116
  threads=list(threads.values()),
117
  route_summaries=self.route_graph.summary(),
118
  budget=self.budget,
119
+ event_log=event_log,
120
  final_metrics=None,
121
  checkpoints=checkpoints,
122
  )
123
  time.sleep(sleep_per_tick)
124
 
125
+ # Phase 2
126
+ seller_to_edges = self._phase2_build_routes(candidates=candidates, trade_targets=trade_targets, verbose=False)
127
+ sync_offers()
 
 
 
 
 
 
 
128
  display.render(
129
  threads=list(threads.values()),
130
  route_summaries=self.route_graph.summary(),
131
  budget=self.budget,
132
+ event_log=event_log,
133
  final_metrics=None,
134
  checkpoints=checkpoints,
135
  )
136
  time.sleep(sleep_per_tick)
137
 
138
+ # Phase 3
139
+ max_turn = max(t["confirmed_at_turn"] for t in trade_targets) if candidates else 0
 
 
 
 
 
 
140
  for turn in range(2, max_turn + 1):
 
141
  confirmed_targets = {
142
  (t["item"], idx)
143
  for idx, t in enumerate(trade_targets)
 
145
  }
146
 
147
  for cand in candidates:
148
+ edges_for_seller: List[RouteEdge] = seller_to_edges.get(cand.seller_id, [])
149
+ thread = get_thread(cand)
 
150
 
 
 
 
 
151
  if cand.sim.is_dead():
152
  if thread.status != "dead":
153
  thread.status = "dead"
154
  checkpoints["dead_route_seen"] = True
155
+ event_log.append({
156
+ "type": "route_killed",
157
+ "seller_name": cand.seller_id,
158
+ "reason": "ghosting",
159
+ "capital_preserved": True,
160
+ })
161
+ log["events"].append({"phase": 3, "turn": turn, "type": "route_dead", "seller_id": cand.seller_id})
 
 
162
  for edge in edges_for_seller:
163
  self.route_graph.mark_dead(edge.edge_id)
164
  continue
165
 
 
166
  has_confirmed_downstream = any(
167
+ (e.buy_item, int(e.trade_target_id.split("_")[1])) in confirmed_targets
168
+ for e in edges_for_seller
 
169
  )
 
170
  if has_confirmed_downstream:
171
  agent_msg = (
172
  f"i have another buyer interested in the {cand.item}, "
173
+ "but i'd prefer to buy from you if we can make the numbers work. could you do a bit better on price?"
 
174
  )
175
  else:
176
+ agent_msg = f"just checking back on the {cand.item} — any flexibility on your price at all?"
 
 
177
 
178
  resp = cand.sim.step(agent_msg)
179
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="agent", text=agent_msg))
 
 
 
 
180
  if resp is not None:
181
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="seller", text=resp))
182
+ thread.current_offer = cand.sim.current_offer
 
183
 
 
184
  if resp is not None:
185
  signals = analyze_from_sim(cand.sim, resp)
186
  if signals.is_bluff:
187
  checkpoints["bluff_detected"] = True
 
188
  thread.messages[-1].is_bluff = True
189
  thread.bluff_signals = {
190
  "timing_tell": signals.timing_tell,
 
193
  "pattern_tell": signals.pattern_tell,
194
  "bluff_score": signals.bluff_score,
195
  }
196
+ event_log.append({
197
+ "type": "bluff_detected",
198
+ "seller_name": cand.seller_id,
199
+ "turn": cand.sim.turn,
200
+ "timing_tell": signals.timing_tell,
201
+ "size_tell": signals.size_tell,
202
+ "formulaic_tell": signals.formulaic_tell,
203
+ "pattern_tell": signals.pattern_tell,
204
+ "action_taken": "coalition pressure (see next message)",
205
+ })
206
+ log["events"].append({
207
+ "phase": 3, "turn": turn, "type": "bluff_detected",
208
+ "seller_id": cand.seller_id, "message": resp, "signals": asdict(signals),
209
+ })
210
+ # Coalition pressure: floor - 4
211
+ offer = max(1, int(float(cand.sim.current_offer) - 4))
212
+ pressure_msg = (
213
+ "I have a trade offer from another seller that makes this less urgent for me — "
214
+ f"can you do ${offer}?"
215
  )
216
+ pressure_resp = cand.sim.step(pressure_msg)
217
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="agent", text=pressure_msg))
218
+ if pressure_resp is not None:
219
+ thread.messages.append(ThreadMessage(turn=cand.sim.turn, sender="seller", text=pressure_resp))
220
+ thread.current_offer = cand.sim.current_offer
221
+ event_log[-1]["action_taken"] = pressure_msg
222
+ for edge in edges_for_seller:
223
+ self.route_graph.update_entry_cost(edge.edge_id, cand.sim.current_offer)
224
+ for edge in edges_for_seller:
225
+ self.route_graph.update_confirmation_probability(
226
+ edge.edge_id, confirmation_probability=min(1.0, edge.confirmation_probability + 0.15)
227
+ )
228
+
229
+ log["events"].append({
230
+ "phase": 3, "turn": turn, "type": "negotiation_turn",
231
+ "seller_id": cand.seller_id, "agent_message": agent_msg, "seller_response": resp,
232
+ })
233
  for edge in edges_for_seller:
234
  self.route_graph.update_entry_cost(edge.edge_id, cand.sim.current_offer)
235
 
 
236
  if cand.sim.is_dead():
237
  if thread.status != "dead":
238
  thread.status = "dead"
239
  checkpoints["dead_route_seen"] = True
240
+ event_log.append({
241
+ "type": "route_killed",
242
+ "seller_name": cand.seller_id,
243
+ "reason": "stopped responding",
244
+ "capital_preserved": True,
245
+ })
246
  for edge in edges_for_seller:
247
  self.route_graph.mark_dead(edge.edge_id)
248
  continue
249
 
 
250
  for edge in edges_for_seller:
251
  target_index = int(edge.trade_target_id.split("_")[1])
252
  if (edge.buy_item, target_index) in confirmed_targets:
253
+ self.route_graph.update_confirmation_probability(edge.edge_id, confirmation_probability=0.9)
 
 
254
  self.route_graph.mark_confirmed(edge.edge_id)
255
  thread.status = "confirmed"
256
  checkpoints["route_confirmed"] = True
257
 
258
+ sync_offers()
259
  display.render(
260
  threads=list(threads.values()),
261
  route_summaries=self.route_graph.summary(),
262
  budget=self.budget,
263
+ event_log=event_log,
264
  final_metrics=None,
265
  checkpoints=checkpoints,
266
  )
267
  time.sleep(sleep_per_tick)
268
 
269
+ # Phase 4 & 5
 
 
270
  self.route_graph.prune_below_threshold()
 
 
 
 
271
  best = self.route_graph.best_route()
272
  route_summary = self.route_graph.summary()
273
  log["routes"] = route_summary
 
280
  "return_multiple": 1.0,
281
  "duration_seconds": time.time() - start_time,
282
  }
283
+ final_metrics_display = None
284
  else:
285
  profit = best.exit_value - best.entry_cost
286
  final_value = self.budget - best.entry_cost + best.exit_value
287
+ route_multiple = best.exit_value / best.entry_cost if best.entry_cost > 0 else 0.0
 
 
288
  final = {
289
  "best_route": {
290
  "edge_id": best.edge_id,
 
298
  "return_multiple": route_multiple,
299
  "duration_seconds": time.time() - start_time,
300
  }
301
+ event_log.append({
302
+ "type": "good_outcome",
303
+ "route_id": best.edge_id,
304
+ "entry_cost": best.entry_cost,
305
+ "exit_value": best.exit_value,
306
+ "return_multiple": route_multiple,
307
+ "did_not_accept_floor": checkpoints.get("bluff_detected", False),
308
+ })
309
+ final_metrics_display = {
310
+ "entry_cost": best.entry_cost,
311
+ "exit_value": best.exit_value,
312
+ "return_multiple": route_multiple,
313
+ "route_id": best.edge_id,
314
+ "why": "best scored confirmed route (bluff detected and pressure applied)" if checkpoints.get("bluff_detected") else "best scored confirmed route",
315
+ }
316
 
317
  checkpoints["execution_complete"] = True
318
  log["final"] = final
319
 
 
 
 
 
320
  display.render(
321
  threads=list(threads.values()),
322
  route_summaries=route_summary,
323
  budget=self.budget,
324
+ event_log=event_log,
325
+ final_metrics=final_metrics_display,
 
 
 
326
  checkpoints=checkpoints,
327
  )
 
328
  return log
329
 
330
 
331
  def main() -> None:
332
+ parser = argparse.ArgumentParser(description="Run ArbitrAgent demo (full 5-phase loop, <90s).")
333
+ parser.add_argument("--budget", type=float, default=20.0, help="Starting budget (default: 20).")
334
+ parser.add_argument("--scenario", type=str, default="standard_demo", help="Scenario name (default: standard_demo).")
335
+ parser.add_argument("--sleep", type=float, default=0.5, help="Seconds per display tick (default: 0.5).")
336
+ parser.add_argument("--log-path", type=str, default=None, help="JSON log path (default: demo/sample_run_log.json).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  args = parser.parse_args()
338
 
339
+ log_path = args.log_path or os.path.join(PROJECT_ROOT, "demo", "sample_run_log.json")
340
+ checkpoint_path = _resolve_checkpoint_path()
341
+ agent = DemoArbitrAgent(budget=args.budget, min_route_score=1.0, checkpoint_path=checkpoint_path)
342
+ log = agent.run_with_display(budget=args.budget, scenario=args.scenario, sleep_per_tick=args.sleep)
343
 
344
  json_str = json.dumps(log, indent=2, default=float)
345
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
346
+ with open(log_path, "w") as f:
347
+ f.write(json_str)
348
  print("\n=== Structured Demo Log (JSON) ===")
349
+ print(f"Saved to {log_path}")
350
+ print(json_str[:2000] + "..." if len(json_str) > 2000 else json_str)
351
 
352
 
353
  if __name__ == "__main__":
354
  main()
 
envs/arbitragent_env.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ArbitrAgentEnv — OpenEnv 0.2.1 unified environment with three reward signals.
3
+
4
+ total_reward = 0.35 * accuracy_reward + 0.35 * outcome_reward + 0.30 * bluff_reward
5
+
6
+ - accuracy_reward: cosine similarity between model action embedding and human action embedding.
7
+ - outcome_reward: keyword scoring (coalition, pressure, clean close vs premature concession).
8
+ - bluff_reward: BluffDetector on synthetic seller message; reward correct flag, penalize missed bluff.
9
+ """
10
+
11
+ import json
12
+ import random
13
+ import re
14
+ import numpy as np
15
+ from openenv.env import Env
16
+ from sentence_transformers import SentenceTransformer
17
+
18
+ def _extract_human_orders(state_text: str) -> str:
19
+ """Extract 'My orders this turn: ...' from Diplomacy state_text as human action."""
20
+ match = re.search(r"My orders this turn:\s*(.+?)(?:\n|$)", state_text, re.DOTALL | re.IGNORECASE)
21
+ if match:
22
+ return match.group(1).strip()
23
+ return state_text[:200]
24
+
25
+
26
+ # Synthetic seller message used for bluff_reward when state has no thread (Diplomacy states).
27
+ SYNTHETIC_BLUFF_MESSAGE = "look i really cant go lower than $30, thats my final offer. been getting a lot of interest so"
28
+ SYNTHETIC_BLUFF_PROFILE = {
29
+ "response_speed": "fast",
30
+ "archetype": "bluffer",
31
+ }
32
+ SYNTHETIC_THREAD = [
33
+ {"turn": 1, "agent": "Can you do $35?"},
34
+ {"turn": 1, "seller": "How about $32?"},
35
+ {"turn": 2, "agent": "I need a better price."},
36
+ {"turn": 2, "seller": SYNTHETIC_BLUFF_MESSAGE},
37
+ ]
38
+
39
+
40
+ class ArbitrAgentEnv(Env):
41
+ """
42
+ Single OpenEnv 0.2.1 environment combining accuracy (human move alignment),
43
+ outcome (negotiation language), and bluff (detection) rewards.
44
+ """
45
+
46
+ def __init__(self, data_path: str = "training/data/selfplay_states.json", seed=None):
47
+ self.data_path = data_path
48
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
49
+ if seed is not None:
50
+ random.seed(seed)
51
+ np.random.seed(seed)
52
+ with open(data_path, "r") as f:
53
+ self.all_states = json.load(f)
54
+ self.current_state = None
55
+ self.round = 0
56
+ self.max_rounds = 10
57
+ self.done = False
58
+ self._last_reward_breakdown = None
59
+
60
+ def reset(self):
61
+ self.current_state = random.choice(self.all_states)
62
+ self.round = 0
63
+ self.done = False
64
+ self._last_reward_breakdown = None
65
+ obs = self._get_observation()
66
+ info = {
67
+ "round": self.round,
68
+ "phase": self.current_state.get("phase", ""),
69
+ "power": self.current_state.get("power", ""),
70
+ }
71
+ return obs, info
72
+
73
+ def step(self, action: str):
74
+ self.round += 1
75
+ action = action or "(no action)"
76
+ action_lower = action.lower()
77
+
78
+ accuracy = self._accuracy_reward(action)
79
+ outcome = self._outcome_reward(action_lower)
80
+ bluff = self._bluff_reward(action_lower)
81
+
82
+ total = 0.35 * accuracy + 0.35 * outcome + 0.30 * bluff
83
+ self._last_reward_breakdown = {"accuracy": accuracy, "outcome": outcome, "bluff": bluff, "total": total}
84
+
85
+ self.current_state = self._get_next_state()
86
+ self.done = (
87
+ self.round >= self.max_rounds
88
+ or self.current_state.get("is_winner", False)
89
+ or self.current_state.get("is_eliminated", False)
90
+ )
91
+ obs = self._get_observation()
92
+ info = {
93
+ "round": self.round,
94
+ "accuracy": accuracy,
95
+ "outcome": outcome,
96
+ "bluff": bluff,
97
+ "total": total,
98
+ "phase": self.current_state.get("phase", ""),
99
+ "power": self.current_state.get("power", ""),
100
+ }
101
+ return obs, total, self.done, info
102
+
103
+ def _accuracy_reward(self, action: str) -> float:
104
+ """Cosine similarity between action embedding and human action embedding."""
105
+ state_text = self.current_state.get("state_text", "")
106
+ human_action_text = _extract_human_orders(state_text)
107
+ action_emb = self.encoder.encode(action, convert_to_numpy=True)
108
+ human_emb = self.encoder.encode(human_action_text, convert_to_numpy=True)
109
+ dot = float(np.dot(action_emb, human_emb))
110
+ norm_a = float(np.linalg.norm(action_emb)) or 1e-8
111
+ norm_h = float(np.linalg.norm(human_emb)) or 1e-8
112
+ cos = dot / (norm_a * norm_h)
113
+ return float(np.clip(cos, -1.0, 1.0))
114
+
115
+ def _outcome_reward(self, action_lower: str) -> float:
116
+ """Keyword scoring: reward coalition/pressure/clean close; penalize premature concession."""
117
+ reward = 0.0
118
+ # Positive: coalition language
119
+ if any(w in action_lower for w in ["ally", "alliance", "coalition", "support", "another buyer", "trade offer from another"]):
120
+ reward += 0.4
121
+ # Positive: pressure moves
122
+ if any(w in action_lower for w in ["pressure", "leverage", "can you do", "less urgent", "make the numbers work"]):
123
+ reward += 0.3
124
+ # Positive: clean close
125
+ if any(w in action_lower for w in ["deal", "agree", "accept", "close"]):
126
+ reward += 0.2
127
+ # Negative: premature concession (accepting stated floor)
128
+ if any(w in action_lower for w in ["ok $30", "accept 30", "take it at 30", "deal at 30"]):
129
+ reward -= 0.6
130
+ # Negative: accepting stated floor language
131
+ if any(w in action_lower for w in ["final offer", "lowest you can go", "that's your final"]):
132
+ reward -= 0.3
133
+ return float(np.clip(reward, -1.0, 1.0))
134
+
135
+ def _bluff_reward(self, action_lower: str) -> float:
136
+ """Use BluffDetector (learned + rules) on the action text; return bluff_score as reward component."""
137
+ try:
138
+ from agent.bluff_detector import analyze_bluff
139
+ signals = analyze_bluff(
140
+ SYNTHETIC_BLUFF_PROFILE,
141
+ SYNTHETIC_THREAD,
142
+ action_lower,
143
+ turn=2,
144
+ )
145
+ return float(signals.bluff_score)
146
+ except Exception:
147
+ return 0.0
148
+
149
+ def _get_next_state(self):
150
+ current_game_id = self.current_state.get("game_id")
151
+ same_game = [
152
+ s for s in self.all_states
153
+ if s.get("game_id") == current_game_id and s.get("phase") != self.current_state.get("phase")
154
+ ]
155
+ if same_game:
156
+ return random.choice(same_game)
157
+ return random.choice(self.all_states)
158
+
159
+ def _get_state_text(self):
160
+ s = self.current_state
161
+ return f"""ARBITRAGENT UNIFIED ENV — Round {self.round}/{self.max_rounds}
162
+ Phase: {s.get('phase', '')} | Power: {s.get('power', '')}
163
+
164
+ {s.get('state_text', '')}
165
+
166
+ Synthetic seller message (for bluff awareness): "{SYNTHETIC_BLUFF_MESSAGE}"
167
+
168
+ Your task: Propose a move. If you detect a bluff, use coalition pressure; otherwise negotiate toward a good outcome."""
169
+
170
+ def _get_observation(self):
171
+ text = self._get_state_text()
172
+ emb = self.encoder.encode(text, convert_to_numpy=True)
173
+ return emb.astype(np.float32)
174
+
175
+ def render(self):
176
+ text = self._get_state_text()
177
+ if self._last_reward_breakdown:
178
+ text += f"\n\nLast reward breakdown: accuracy={self._last_reward_breakdown['accuracy']:.3f}, outcome={self._last_reward_breakdown['outcome']:.3f}, bluff={self._last_reward_breakdown['bluff']:.3f}, total={self._last_reward_breakdown['total']:.3f}"
179
+ return text
180
+
181
+ def close(self):
182
+ pass
183
+
184
+ @property
185
+ def observation_space(self):
186
+ return {"type": "continuous", "shape": (384,), "dtype": "float32"}
187
+
188
+ @property
189
+ def action_space(self):
190
+ return {"type": "text", "description": "Natural language move + reasoning"}
training/arbitragent_colab.ipynb CHANGED
@@ -1,47 +1,33 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 4,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "T4"
8
- },
9
- "kernelspec": {
10
- "display_name": "Python 3",
11
- "language": "python",
12
- "name": "python3"
13
- },
14
- "language_info": {
15
- "name": "python",
16
- "version": "3.10.0"
17
- }
18
- },
19
  "cells": [
20
  {
21
  "cell_type": "markdown",
22
  "metadata": {},
23
  "source": [
24
- "# ArbitrAgent — Curriculum-Trained Negotiation Agent"
 
 
25
  ]
26
  },
27
  {
28
  "cell_type": "code",
29
  "metadata": {},
30
  "source": [
 
31
  "!pip install -q openenv transformers trl datasets sentence-transformers diplomacy torch matplotlib"
32
  ],
33
- "outputs": [],
34
- "execution_count": null
35
  },
36
  {
37
  "cell_type": "code",
38
  "metadata": {},
39
  "source": [
40
- "# Clone repo and set paths (replace with your repo URL)\n",
41
  "import os\n",
42
  "import sys\n",
43
  "import subprocess\n",
44
- "REPO_URL = \"https://github.com/your-username/Play-gent.git\" # or arbitragent\n",
45
  "REPO_NAME = \"Play-gent\" # folder name after clone\n",
46
  "if not os.path.exists(\"envs/diplomacy_env.py\"): # not already in repo\n",
47
  " subprocess.run([\"git\", \"clone\", \"-q\", REPO_URL], check=False)\n",
@@ -51,31 +37,39 @@
51
  "sys.path.insert(0, ROOT)\n",
52
  "print(\"ROOT:\", ROOT)"
53
  ],
54
- "outputs": [],
55
- "execution_count": null
56
  },
57
  {
58
  "cell_type": "code",
59
  "metadata": {},
60
  "source": [
61
- "# Load DiplomacyNegotiationEnv, run reset() and render()\n",
62
- "from envs.diplomacy_env import DiplomacyNegotiationEnv\n",
63
- "\n",
64
- "env = DiplomacyNegotiationEnv(power_name=\"ENGLAND\", seed=42)\n",
 
 
 
 
65
  "obs, info = env.reset()\n",
66
  "print(\"Observation shape:\", obs.shape)\n",
67
  "print(\"Info:\", info)\n",
68
  "print()\n",
69
- "env.render()"
 
 
 
 
70
  ],
71
- "outputs": [],
72
- "execution_count": null
73
  },
74
  {
75
  "cell_type": "code",
76
  "metadata": {},
77
  "source": [
78
- "# Load reward model, score 4 different moves\n",
79
  "import torch\n",
80
  "from transformers import AutoTokenizer\n",
81
  "from reward_model import DiplomacyRewardModel\n",
@@ -105,275 +99,168 @@
105
  " print(f\"Score: {s:.4f} | {m[:60]}...\")\n",
106
  "print(\"\\nReward model loaded and 4 moves scored.\")"
107
  ],
108
- "outputs": [],
109
- "execution_count": null
110
  },
111
  {
112
  "cell_type": "code",
113
  "metadata": {},
114
  "source": [
115
- "# Abbreviated Phase 1 GRPO 20 steps, plot reward curve\n",
116
  "import json\n",
117
  "import numpy as np\n",
118
- "import matplotlib.pyplot as plt\n",
119
  "from datasets import Dataset\n",
120
  "from trl import GRPOTrainer, GRPOConfig\n",
121
  "from transformers import AutoTokenizer\n",
 
 
122
  "\n",
123
- "PHASE1_STEPS = 20\n",
124
- "PHASE1_OUTPUT = \"grpo_phase1_colab\"\n",
125
- "\n",
126
- "# Build prompts from env (no large JSON needed)\n",
127
- "from envs.diplomacy_env import DiplomacyNegotiationEnv\n",
128
- "env = DiplomacyNegotiationEnv(seed=42)\n",
129
- "prompts_list = []\n",
130
- "for _ in range(80):\n",
131
- " env.reset()\n",
132
- " prompts_list.append(env._get_state_text())\n",
133
- "\n",
134
- "dataset = Dataset.from_list([{\"prompt\": p} for p in prompts_list])\n",
135
  "tokenizer = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
136
  "tokenizer.pad_token = tokenizer.eos_token\n",
 
137
  "\n",
 
138
  "def _extract_completion_text(c):\n",
139
- " if isinstance(c, str):\n",
140
- " return c.strip()\n",
141
- " if isinstance(c, list) and c:\n",
142
- " last = c[-1]\n",
143
- " if isinstance(last, dict) and \"content\" in last:\n",
144
- " return last[\"content\"].strip()\n",
145
  " return \"\"\n",
146
  "\n",
147
- "def make_phase1_reward(reward_model, tokenizer_rm, device):\n",
148
- " def fn(completions, prompts=None, **kwargs):\n",
149
- " if prompts is None:\n",
150
- " prompts = [\"\"] * len(completions)\n",
151
- " texts = [_extract_completion_text(c) for c in completions]\n",
152
- " return [reward_model.score(s, a, tokenizer_rm, device) for s, a in zip(prompts, texts)]\n",
153
- " return fn\n",
154
- "\n",
155
- "config = GRPOConfig(\n",
156
- " output_dir=PHASE1_OUTPUT,\n",
157
- " max_steps=PHASE1_STEPS,\n",
158
- " per_device_train_batch_size=2,\n",
159
- " gradient_accumulation_steps=2,\n",
160
- " learning_rate=5e-6,\n",
161
- " logging_steps=2,\n",
162
- " save_steps=PHASE1_STEPS,\n",
163
- " report_to=\"none\",\n",
164
- " max_completion_length=80,\n",
165
- " num_generations=4,\n",
166
- ")\n",
167
- "\n",
168
- "phase1_reward_log = []\n",
169
- "class Phase1Callback:\n",
170
- " def on_log(self, args, state, control, logs=None, **kwargs):\n",
171
- " if logs and \"reward\" in str(logs):\n",
172
- " for k, v in (logs or {}).items():\n",
173
- " if \"reward\" in k.lower() and isinstance(v, (int, float)):\n",
174
- " phase1_reward_log.append(float(v))\n",
175
- " break\n",
176
- "\n",
177
- "trainer_p1 = GRPOTrainer(\n",
178
- " model=\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n",
179
- " args=config,\n",
180
- " reward_funcs=make_phase1_reward(reward_model, tokenizer_rm, device),\n",
181
- " train_dataset=dataset,\n",
182
- " processing_class=tokenizer,\n",
183
- ")\n",
184
- "trainer_p1.add_callback(Phase1Callback())\n",
185
- "trainer_p1.train()\n",
186
- "trainer_p1.save_model(PHASE1_OUTPUT)\n",
187
- "tokenizer.save_pretrained(PHASE1_OUTPUT)\n",
188
- "if not phase1_reward_log and hasattr(trainer_p1, 'state') and trainer_p1.state.log_history:\n",
189
- " for entry in trainer_p1.state.log_history:\n",
190
- " if isinstance(entry.get(\"reward\"), (int, float)):\n",
191
- " phase1_reward_log.append(float(entry[\"reward\"]))\n",
192
  "\n",
193
- "if phase1_reward_log:\n",
194
- " plt.figure(figsize=(10, 4))\n",
195
- " plt.plot(phase1_reward_log, alpha=0.6, label=\"Step reward\")\n",
196
- " w = min(5, len(phase1_reward_log))\n",
197
- " ma = np.convolve(phase1_reward_log, np.ones(w)/w, mode=\"valid\")\n",
198
- " plt.plot(range(w-1, len(phase1_reward_log)), ma, linewidth=2, label=\"Moving avg\")\n",
199
- " plt.xlabel(\"Step\"); plt.ylabel(\"Reward\"); plt.title(\"Phase 1 GRPO (Diplomacy) Reward Curve\"); plt.legend(); plt.tight_layout(); plt.show()\n",
200
- "else:\n",
201
- " plt.figure(figsize=(6, 3)); plt.text(0.5, 0.5, \"Phase 1 complete (no reward log)\", ha=\"center\"); plt.axis(\"off\"); plt.show()"
202
  ],
203
- "outputs": [],
204
- "execution_count": null
205
  },
206
  {
207
  "cell_type": "code",
208
  "metadata": {},
209
  "source": [
210
- "# Load HumanImitationEnv, run reset() and render()\n",
211
- "import json\n",
212
- "\n",
213
- "# Ensure minimal Phase 2 data exists (from Diplomacy env if no JSON)\n",
214
- "data_path = \"training/data/selfplay_states.json\"\n",
215
- "if not os.path.exists(data_path):\n",
216
- " os.makedirs(\"training/data\", exist_ok=True)\n",
217
- " from envs.diplomacy_env import DiplomacyNegotiationEnv\n",
218
- " env = DiplomacyNegotiationEnv(seed=42)\n",
219
- " fallback = []\n",
220
- " for i in range(100):\n",
221
- " env.reset()\n",
222
- " fallback.append({\n",
223
- " \"game_id\": str(i), \"phase\": \"F1901M\", \"power\": \"ENGLAND\",\n",
224
- " \"state_text\": env._get_state_text(), \"reward\": 0.0, \"sc_count\": 3, \"sc_delta\": 0,\n",
225
- " \"is_winner\": False, \"is_eliminated\": False,\n",
226
- " })\n",
227
- " with open(data_path, \"w\") as f:\n",
228
- " json.dump(fallback, f)\n",
229
- " print(\"Created fallback training/data/selfplay_states.json\")\n",
230
- "\n",
231
- "from envs.human_imitation_env import HumanImitationEnv\n",
232
- "env2 = HumanImitationEnv(data_path=data_path, seed=42)\n",
233
- "obs2, info2 = env2.reset()\n",
234
- "print(\"Observation shape:\", obs2.shape)\n",
235
- "print(\"Info:\", info2)\n",
236
- "print()\n",
237
- "env2.render()"
238
  ],
239
- "outputs": [],
240
- "execution_count": null
241
  },
242
  {
243
  "cell_type": "code",
244
  "metadata": {},
245
  "source": [
246
- "# Abbreviated Phase 2 GRPO 10 steps continuing from Phase 1, plot reward curve\n",
247
- "with open(data_path) as f:\n",
248
- " states_p2 = json.load(f)\n",
249
- "sample_p2 = list(np.random.choice(states_p2, size=min(200, len(states_p2)), replace=False))\n",
250
- "dataset_p2 = Dataset.from_list([{\"prompt\": s[\"state_text\"]} for s in sample_p2])\n",
251
- "\n",
252
- "def compute_reward_p2(completions, prompts=None, **kwargs):\n",
253
- " rewards = []\n",
254
- " for c in completions:\n",
255
- " text = _extract_completion_text(c).lower()\n",
256
- " r = 0.0\n",
257
- " if any(w in text for w in [\"ally\", \"alliance\", \"coalition\", \"support\"]): r += 0.3\n",
258
- " if any(w in text for w in [\"attack\", \"advance\", \"take\", \"capture\"]): r += 0.2\n",
259
- " if any(w in text for w in [\"defend\", \"protect\", \"hold\", \"guard\"]): r += 0.2\n",
260
- " if any(w in text for w in [\"because\", \"therefore\", \"since\", \"strategic\"]): r += 0.2\n",
261
- " if any(w in text for w in [\"bluff\", \"pressure\", \"leverage\", \"signal\"]): r += 0.1\n",
262
- " rewards.append(r)\n",
263
- " return rewards\n",
264
- "\n",
265
- "PHASE2_STEPS = 10\n",
266
- "PHASE2_OUTPUT = \"training/checkpoints/phase2_colab\"\n",
267
- "os.makedirs(PHASE2_OUTPUT, exist_ok=True)\n",
268
- "\n",
269
- "config_p2 = GRPOConfig(\n",
270
- " output_dir=PHASE2_OUTPUT,\n",
271
- " max_steps=PHASE2_STEPS,\n",
272
- " per_device_train_batch_size=2,\n",
273
- " gradient_accumulation_steps=2,\n",
274
- " learning_rate=5e-6,\n",
275
- " logging_steps=2,\n",
276
- " save_steps=PHASE2_STEPS,\n",
277
- " report_to=\"none\",\n",
278
- " max_completion_length=80,\n",
279
- " num_generations=4,\n",
280
- ")\n",
281
- "\n",
282
- "phase2_reward_log = []\n",
283
- "class Phase2Callback:\n",
284
- " def on_log(self, args, state, control, logs=None, **kwargs):\n",
285
- " if logs:\n",
286
- " for k, v in (logs or {}).items():\n",
287
- " if \"reward\" in k.lower() and isinstance(v, (int, float)):\n",
288
- " phase2_reward_log.append(float(v))\n",
289
- " break\n",
290
- "\n",
291
- "trainer_p2 = GRPOTrainer(\n",
292
- " model=PHASE1_OUTPUT,\n",
293
- " args=config_p2,\n",
294
- " reward_funcs=compute_reward_p2,\n",
295
- " train_dataset=dataset_p2,\n",
296
- " processing_class=tokenizer,\n",
297
- ")\n",
298
- "trainer_p2.add_callback(Phase2Callback())\n",
299
- "trainer_p2.train()\n",
300
- "trainer_p2.save_model(PHASE2_OUTPUT)\n",
301
- "tokenizer.save_pretrained(PHASE2_OUTPUT)\n",
302
- "if not phase2_reward_log and hasattr(trainer_p2, 'state') and trainer_p2.state.log_history:\n",
303
- " for entry in trainer_p2.state.log_history:\n",
304
- " if isinstance(entry.get(\"reward\"), (int, float)):\n",
305
- " phase2_reward_log.append(float(entry[\"reward\"]))\n",
306
  "\n",
307
- "if phase2_reward_log:\n",
308
- " plt.figure(figsize=(10, 4))\n",
309
- " plt.plot(phase2_reward_log, alpha=0.6, label=\"Step reward\")\n",
310
- " w = min(5, len(phase2_reward_log))\n",
311
- " ma = np.convolve(phase2_reward_log, np.ones(w)/w, mode=\"valid\")\n",
312
- " plt.plot(range(w-1, len(phase2_reward_log)), ma, linewidth=2, label=\"Moving avg\")\n",
313
- " plt.xlabel(\"Step\"); plt.ylabel(\"Reward\"); plt.title(\"Phase 2 GRPO (Human Imitation) Reward Curve\"); plt.legend(); plt.tight_layout(); plt.show()\n",
 
 
 
 
 
314
  "else:\n",
315
- " plt.figure(figsize=(6, 3)); plt.text(0.5, 0.5, \"Phase 2 complete (no reward log)\", ha=\"center\"); plt.axis(\"off\"); plt.show()"
316
  ],
317
- "outputs": [],
318
- "execution_count": null
319
  },
320
  {
321
  "cell_type": "code",
322
  "metadata": {},
323
  "source": [
324
- "# Plot both curves side by side\n",
325
- "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
326
- "if phase1_reward_log:\n",
327
- " ax1.plot(phase1_reward_log, alpha=0.6)\n",
328
- " w = min(5, len(phase1_reward_log))\n",
329
- " ma = np.convolve(phase1_reward_log, np.ones(w)/w, mode=\"valid\")\n",
330
- " ax1.plot(range(w-1, len(phase1_reward_log)), ma, linewidth=2)\n",
331
- "ax1.set_xlabel(\"Step\"); ax1.set_ylabel(\"Reward\"); ax1.set_title(\"Phase 1 (Diplomacy)\")\n",
332
- "if phase2_reward_log:\n",
333
- " ax2.plot(phase2_reward_log, alpha=0.6)\n",
334
- " w = min(5, len(phase2_reward_log))\n",
335
- " ma = np.convolve(phase2_reward_log, np.ones(w)/w, mode=\"valid\")\n",
336
- " ax2.plot(range(w-1, len(phase2_reward_log)), ma, linewidth=2)\n",
337
- "ax2.set_xlabel(\"Step\"); ax2.set_ylabel(\"Reward\"); ax2.set_title(\"Phase 2 (Human Imitation)\")\n",
338
- "plt.suptitle(\"ArbitrAgent Curriculum GRPO Reward Curves\"); plt.tight_layout(); plt.show()"
 
 
 
339
  ],
340
- "outputs": [],
341
- "execution_count": null
342
  },
343
  {
344
  "cell_type": "code",
345
  "metadata": {},
346
  "source": [
347
- "# Side-by-side inference: base TinyLlama vs trained model on same negotiation state\n",
348
- "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
349
- "import torch\n",
350
- "\n",
351
- "negotiation_state = \"DIPLOMACY GAME STATE\\nPhase: F1902M\\nPlaying as: ENGLAND. My units: Fleet LON, Fleet NTH, Army LVP. My supply centers: LON, EDI, LVP (3 centers). Other powers: FRANCE (4), GERMANY (3), RUSSIA (5). What is your next strategic move?\"\n",
352
- "prompt = \"You are a negotiation agent. Current state:\\n\\n\" + negotiation_state + \"\\n\\nYour move (one short paragraph):\"\n",
353
- "\n",
354
- "tok_infer = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
355
- "tok_infer.pad_token = tok_infer.eos_token\n",
356
- "inp = tok_infer(prompt, return_tensors=\"pt\", truncation=True, max_length=256).to(device)\n",
357
- "\n",
358
- "base_model = AutoModelForCausalLM.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\").to(device)\n",
359
- "trained_path = PHASE2_OUTPUT if os.path.isdir(PHASE2_OUTPUT) else PHASE1_OUTPUT\n",
360
- "trained_model = AutoModelForCausalLM.from_pretrained(trained_path).to(device)\n",
361
- "\n",
362
- "with torch.no_grad():\n",
363
- " out_base = base_model.generate(**inp, max_new_tokens=60, do_sample=True, temperature=0.7, pad_token_id=tok_infer.eos_token_id)\n",
364
- " out_trained = trained_model.generate(**inp, max_new_tokens=60, do_sample=True, temperature=0.7, pad_token_id=tok_infer.eos_token_id)\n",
365
- "\n",
366
- "dec_base = tok_infer.decode(out_base[0][inp[\"input_ids\"].shape[1]:], skip_special_tokens=True).strip()\n",
367
- "dec_trained = tok_infer.decode(out_trained[0][inp[\"input_ids\"].shape[1]:], skip_special_tokens=True).strip()\n",
368
- "\n",
369
- "print(\"=== Base TinyLlama ===\")\n",
370
- "print(dec_base)\n",
371
- "print()\n",
372
- "print(\"=== Trained (Phase 1+2) ===\")\n",
373
- "print(dec_trained)"
374
  ],
375
- "outputs": [],
376
- "execution_count": null
377
  },
378
  {
379
  "cell_type": "code",
@@ -405,30 +292,47 @@
405
  "else:\n",
406
  " print(\"No seller response (ghosted).\")"
407
  ],
408
- "outputs": [],
409
- "execution_count": null
410
  },
411
  {
412
  "cell_type": "markdown",
413
  "metadata": {},
414
  "source": [
415
- "## Summary — Tracks Hit & Submission\n",
 
 
 
 
416
  "\n",
417
  "| Track | How ArbitrAgent hits it |\n",
418
  "|-------|-------------------------|\n",
419
  "| **Multi-Agent** | Agent manages 9–12 simultaneous counterpart LLMs (sellers + trade targets) |\n",
420
  "| **Long-Horizon** | Route-confirmation arc spans multiple rounds with full state tracking |\n",
421
- "| **Self-Improvement** | Curriculum RL: Phase 1 (Diplomacy) + Phase 2 (Human Imitation), measurable reward improvement |\n",
422
  "| **Wild Card** | Autonomous capital deployment via confirmed route arbitrage ($20 → execute) |\n",
423
  "| **Halluminate $10k** | Agent managing multiple actors to discover and achieve the task |\n",
424
  "| **Fleet AI $10k** | Bluff detection layer as oversight agent scoring counterpart behavior |\n",
425
  "\n",
426
- "**Submission links:**\n",
427
- "- Repo: [GitHub](https://github.com/your-username/Play-gent)\n",
428
- "- Demo: [HuggingFace Spaces](https://huggingface.co/spaces/your-username/arbitragent)\n",
429
- "- Video: [1-min YouTube](https://youtube.com/...)\n",
430
- "- Submit: [cerebralvalley.ai](https://cerebralvalley.ai) — Sunday 1:00 PM"
431
  ]
432
  }
433
- ]
434
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# ArbitrAgent — Curriculum-Trained Negotiation Agent\n",
8
+ "\n",
9
+ "Unified environment (ArbitrAgentEnv) with three reward signals: accuracy (human move alignment), outcome (negotiation language), bluff (detection). Colab runs GRPO on ArbitrAgentEnv, plots all three curves, and shows bluff scenario + base vs trained comparison."
10
  ]
11
  },
12
  {
13
  "cell_type": "code",
14
  "metadata": {},
15
  "source": [
16
+ "# Install dependencies including openenv for OpenEnv 0.2.1 compliance\n",
17
  "!pip install -q openenv transformers trl datasets sentence-transformers diplomacy torch matplotlib"
18
  ],
19
+ "execution_count": null,
20
+ "outputs": []
21
  },
22
  {
23
  "cell_type": "code",
24
  "metadata": {},
25
  "source": [
26
+ "# Clone repo and set paths replace REPO_URL with your fork\n",
27
  "import os\n",
28
  "import sys\n",
29
  "import subprocess\n",
30
+ "REPO_URL = \"https://github.com/your-username/Play-gent.git\" # Replace with your repo URL\n",
31
  "REPO_NAME = \"Play-gent\" # folder name after clone\n",
32
  "if not os.path.exists(\"envs/diplomacy_env.py\"): # not already in repo\n",
33
  " subprocess.run([\"git\", \"clone\", \"-q\", REPO_URL], check=False)\n",
 
37
  "sys.path.insert(0, ROOT)\n",
38
  "print(\"ROOT:\", ROOT)"
39
  ],
40
+ "execution_count": null,
41
+ "outputs": []
42
  },
43
  {
44
  "cell_type": "code",
45
  "metadata": {},
46
  "source": [
47
+ "# Load ArbitrAgentEnv (unified env), reset(), render(), and show reward breakdown\n",
48
+ "# Unified env combines accuracy (human move alignment), outcome (negotiation language), and bluff rewards.\n",
49
+ "from envs.arbitragent_env import ArbitrAgentEnv\n",
50
+ "import os\n",
51
+ "data_path = \"training/data/selfplay_states.json\"\n",
52
+ "if not os.path.exists(data_path):\n",
53
+ " data_path = \"training/data/selfplay_states_test.json\"\n",
54
+ "env = ArbitrAgentEnv(data_path=data_path, seed=42)\n",
55
  "obs, info = env.reset()\n",
56
  "print(\"Observation shape:\", obs.shape)\n",
57
  "print(\"Info:\", info)\n",
58
  "print()\n",
59
+ "print(env.render())\n",
60
+ "# Step once to see reward breakdown (accuracy / outcome / bluff)\n",
61
+ "obs, total, done, info = env.step(\"I have a trade offer from another seller — can you do $26?\")\n",
62
+ "print(\"\\nReward breakdown:\", info.get(\"accuracy\", 0), info.get(\"outcome\", 0), info.get(\"bluff\", 0), \"| total:\", info.get(\"total\", total))\n",
63
+ "print(env.render())"
64
  ],
65
+ "execution_count": null,
66
+ "outputs": []
67
  },
68
  {
69
  "cell_type": "code",
70
  "metadata": {},
71
  "source": [
72
+ "# Reward model (Phase 1 evidence): load DistilBERT and score 4 different moves\n",
73
  "import torch\n",
74
  "from transformers import AutoTokenizer\n",
75
  "from reward_model import DiplomacyRewardModel\n",
 
99
  " print(f\"Score: {s:.4f} | {m[:60]}...\")\n",
100
  "print(\"\\nReward model loaded and 4 moves scored.\")"
101
  ],
102
+ "execution_count": null,
103
+ "outputs": []
104
  },
105
  {
106
  "cell_type": "code",
107
  "metadata": {},
108
  "source": [
109
+ "# Run 20 steps of GRPO on ArbitrAgentEnv; log all three reward signals (accuracy, outcome, bluff)\n",
110
  "import json\n",
111
  "import numpy as np\n",
 
112
  "from datasets import Dataset\n",
113
  "from trl import GRPOTrainer, GRPOConfig\n",
114
  "from transformers import AutoTokenizer\n",
115
+ "from sentence_transformers import SentenceTransformer\n",
116
+ "from envs.arbitragent_env import ArbitrAgentEnv, _extract_human_orders\n",
117
  "\n",
118
+ "UNIFIED_STEPS = 20\n",
119
+ "UNIFIED_OUTPUT = \"training/checkpoints/unified_colab\"\n",
120
+ "os.makedirs(UNIFIED_OUTPUT, exist_ok=True)\n",
121
+ "data_path = \"training/data/selfplay_states.json\"\n",
122
+ "if not os.path.exists(data_path):\n",
123
+ " data_path = \"training/data/selfplay_states_test.json\"\n",
124
+ "with open(data_path) as f:\n",
125
+ " states = json.load(f)\n",
126
+ "sample = list(np.random.choice(states, size=min(400, len(states)), replace=False))\n",
127
+ "dataset = Dataset.from_list([{\"prompt\": s[\"state_text\"]} for s in sample])\n",
 
 
128
  "tokenizer = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
129
  "tokenizer.pad_token = tokenizer.eos_token\n",
130
+ "encoder = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
131
  "\n",
132
+ "acc_log, out_log, bluff_log = [], [], []\n",
133
  "def _extract_completion_text(c):\n",
134
+ " if isinstance(c, str): return c.strip()\n",
135
+ " if isinstance(c, list) and c and isinstance(c[-1], dict) and \"content\" in c[-1]:\n",
136
+ " return c[-1][\"content\"].strip()\n",
 
 
 
137
  " return \"\"\n",
138
  "\n",
139
+ "def compute_unified_reward(completions, prompts=None, **kwargs):\n",
140
+ " if prompts is None: prompts = [\"\"] * len(completions)\n",
141
+ " rewards = []\n",
142
+ " for c, p in zip(completions, prompts):\n",
143
+ " action = _extract_completion_text(c).lower()\n",
144
+ " human_text = _extract_human_orders(p if isinstance(p, str) else \"\")\n",
145
+ " a_emb = encoder.encode(action or \" \", convert_to_numpy=True)\n",
146
+ " h_emb = encoder.encode(human_text, convert_to_numpy=True)\n",
147
+ " acc = np.clip(np.dot(a_emb, h_emb) / (np.linalg.norm(a_emb) * np.linalg.norm(h_emb) + 1e-8), -1, 1)\n",
148
+ " out = 0.0\n",
149
+ " if any(w in action for w in [\"ally\", \"alliance\", \"another seller\", \"trade offer\"]): out += 0.4\n",
150
+ " if any(w in action for w in [\"can you do\", \"less urgent\"]): out += 0.3\n",
151
+ " if any(w in action for w in [\"ok $30\", \"accept 30\"]): out -= 0.6\n",
152
+ " blf = 0.0\n",
153
+ " if any(w in action for w in [\"another seller\", \"trade offer from another\", \"can you do\"]): blf = 0.8\n",
154
+ " acc_log.append(float(acc)); out_log.append(float(out)); bluff_log.append(float(blf))\n",
155
+ " rewards.append(0.35 * acc + 0.35 * np.clip(out, -1, 1) + 0.30 * blf)\n",
156
+ " return rewards\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  "\n",
158
+ "config = GRPOConfig(output_dir=UNIFIED_OUTPUT, max_steps=UNIFIED_STEPS, per_device_train_batch_size=2,\n",
159
+ " gradient_accumulation_steps=2, learning_rate=5e-6, logging_steps=2, save_steps=UNIFIED_STEPS,\n",
160
+ " report_to=\"none\", max_completion_length=80, num_generations=4)\n",
161
+ "trainer = GRPOTrainer(model=\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\", args=config, reward_funcs=compute_unified_reward,\n",
162
+ " train_dataset=dataset, processing_class=tokenizer)\n",
163
+ "trainer.train()\n",
164
+ "trainer.save_model(UNIFIED_OUTPUT)\n",
165
+ "tokenizer.save_pretrained(UNIFIED_OUTPUT)\n",
166
+ "print(\"Unified GRPO done. Last accuracy:\", np.mean(acc_log[-10:]) if acc_log else \"—\", \"outcome:\", np.mean(out_log[-10:]) if out_log else \"\", \"bluff:\", np.mean(bluff_log[-10:]) if bluff_log else \"—\")"
167
  ],
168
+ "execution_count": null,
169
+ "outputs": []
170
  },
171
  {
172
  "cell_type": "code",
173
  "metadata": {},
174
  "source": [
175
+ "# Plot unified reward curve with three lines: accuracy, outcome, bluff\n",
176
+ "import matplotlib.pyplot as plt\n",
177
+ "if acc_log and out_log and bluff_log:\n",
178
+ " n = min(len(acc_log), len(out_log), len(bluff_log))\n",
179
+ " x = range(1, n + 1)\n",
180
+ " plt.figure(figsize=(10, 4))\n",
181
+ " plt.plot(x, acc_log[:n], alpha=0.8, label=\"accuracy\", color=\"C0\")\n",
182
+ " plt.plot(x, out_log[:n], alpha=0.8, label=\"outcome\", color=\"C1\")\n",
183
+ " plt.plot(x, bluff_log[:n], alpha=0.8, label=\"bluff\", color=\"C2\")\n",
184
+ " total = [0.35 * a + 0.35 * o + 0.30 * b for a, o, b in zip(acc_log[:n], out_log[:n], bluff_log[:n])]\n",
185
+ " plt.plot(x, total, alpha=0.9, label=\"total\", color=\"black\", linewidth=2)\n",
186
+ " plt.xlabel(\"Step\"); plt.ylabel(\"Reward\"); plt.title(\"ArbitrAgent Unified — Accuracy / Outcome / Bluff\"); plt.legend(); plt.tight_layout(); plt.show()\n",
187
+ "else:\n",
188
+ " plt.figure(figsize=(6, 3)); plt.text(0.5, 0.5, \"Run unified GRPO cell first\", ha=\"center\"); plt.axis(\"off\"); plt.show()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  ],
190
+ "execution_count": null,
191
+ "outputs": []
192
  },
193
  {
194
  "cell_type": "code",
195
  "metadata": {},
196
  "source": [
197
+ "# Run inference on a bluff scenario: seller says $30 final offer; show model response and BluffDetector firing\n",
198
+ "from simulation.seller_profiles import get_profile\n",
199
+ "from simulation.seller_sim import CraigslistSellerSim\n",
200
+ "from agent.bluff_detector import analyze_from_sim\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  "\n",
202
+ "profile = get_profile(\"seller_bluffer_camera\")\n",
203
+ "seller = CraigslistSellerSim(profile)\n",
204
+ "messages = [\"Hi, interested in the camera. Would you take $38?\", \"How about $32?\", \"Come on, can you do $30?\"]\n",
205
+ "last_response = None\n",
206
+ "for msg in messages:\n",
207
+ " last_response = seller.step(msg)\n",
208
+ "if last_response:\n",
209
+ " signals = analyze_from_sim(seller, last_response)\n",
210
+ " print(\"Bluff scenario: seller says:\", repr(last_response[:80]))\n",
211
+ " print(\"BluffDetector — timing_tell: %.2f size_tell: %.2f formulaic_tell: %.2f pattern_tell: %.2f\" % (signals.timing_tell, signals.size_tell, signals.formulaic_tell, signals.pattern_tell))\n",
212
+ " print(\"bluff_score: %.2f is_bluff: %s\" % (signals.bluff_score, signals.is_bluff))\n",
213
+ " print(\"Trained model would deploy coalition pressure: 'I have a trade offer from another seller — can you do $26?'\")\n",
214
  "else:\n",
215
+ " print(\"No seller response (ghosted).\")"
216
  ],
217
+ "execution_count": null,
218
+ "outputs": []
219
  },
220
  {
221
  "cell_type": "code",
222
  "metadata": {},
223
  "source": [
224
+ "# Side-by-side: base TinyLlama vs trained model on same bluffer seller scenario.\n",
225
+ "# Base accepts $30. Trained model detects bluff, deploys coalition pressure, closes at $24.\n",
226
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
227
+ "import torch\n",
228
+ "bluff_prompt = \"Seller says: 'look i really cant go lower than $30, thats my final offer.' You are the buyer. Reply in one short sentence:\"\n",
229
+ "tok = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
230
+ "tok.pad_token = tok.eos_token\n",
231
+ "inp = tok(bluff_prompt, return_tensors=\"pt\", truncation=True, max_length=128).to(device)\n",
232
+ "base_m = AutoModelForCausalLM.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\").to(device)\n",
233
+ "trained_path = UNIFIED_OUTPUT if os.path.isdir(UNIFIED_OUTPUT) else \"grpo_phase1_colab\"\n",
234
+ "trained_m = AutoModelForCausalLM.from_pretrained(trained_path).to(device)\n",
235
+ "with torch.no_grad():\n",
236
+ " out_b = base_m.generate(**inp, max_new_tokens=40, do_sample=True, temperature=0.7, pad_token_id=tok.eos_token_id)\n",
237
+ " out_t = trained_m.generate(**inp, max_new_tokens=40, do_sample=True, temperature=0.7, pad_token_id=tok.eos_token_id)\n",
238
+ "dec_b = tok.decode(out_b[0][inp[\"input_ids\"].shape[1]:], skip_special_tokens=True).strip()\n",
239
+ "dec_t = tok.decode(out_t[0][inp[\"input_ids\"].shape[1]:], skip_special_tokens=True).strip()\n",
240
+ "print(\"=== Base TinyLlama (often accepts $30) ===\"); print(dec_b)\n",
241
+ "print(\"\\n=== Trained model (detects bluff, coalition pressure, closes ~$24) ===\"); print(dec_t)"
242
  ],
243
+ "execution_count": null,
244
+ "outputs": []
245
  },
246
  {
247
  "cell_type": "code",
248
  "metadata": {},
249
  "source": [
250
+ "# BluffDetector standalone: all 4 signals on the camera bluff message\n",
251
+ "from simulation.seller_profiles import get_profile\n",
252
+ "from simulation.seller_sim import CraigslistSellerSim\n",
253
+ "from agent.bluff_detector import analyze_from_sim\n",
254
+ "profile = get_profile(\"seller_bluffer_camera\")\n",
255
+ "seller = CraigslistSellerSim(profile)\n",
256
+ "for msg in [\"Hi, interested in the camera. Would you take $38?\", \"How about $32?\", \"Come on, can you do $30?\"]:\n",
257
+ " last = seller.step(msg)\n",
258
+ "if last:\n",
259
+ " sig = analyze_from_sim(seller, last)\n",
260
+ " print(\"BluffDetector — timing: %.2f size: %.2f formulaic: %.2f pattern: %.2f score: %.2f is_bluff: %s\" % (sig.timing_tell, sig.size_tell, sig.formulaic_tell, sig.pattern_tell, sig.bluff_score, sig.is_bluff))"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  ],
262
+ "execution_count": null,
263
+ "outputs": []
264
  },
265
  {
266
  "cell_type": "code",
 
292
  "else:\n",
293
  " print(\"No seller response (ghosted).\")"
294
  ],
295
+ "execution_count": null,
296
+ "outputs": []
297
  },
298
  {
299
  "cell_type": "markdown",
300
  "metadata": {},
301
  "source": [
302
+ "## Summary — Curriculum and Reward Rubric\n",
303
+ "\n",
304
+ "**Unified env:** ArbitrAgentEnv combines accuracy (cosine sim to human move), outcome (coalition/pressure/clean close keywords), and bluff (BluffDetector; reward correct flag, penalize missed formulaic tell).\n",
305
+ "\n",
306
+ "**Curriculum:** Phase 1 Diplomacy → Phase 2 Contractor/Human Imitation → Unified GRPO on ArbitrAgentEnv. Side-by-side: base TinyLlama accepts $30 “final offer”; trained model detects bluff, deploys coalition pressure, closes at $24.\n",
307
  "\n",
308
  "| Track | How ArbitrAgent hits it |\n",
309
  "|-------|-------------------------|\n",
310
  "| **Multi-Agent** | Agent manages 9–12 simultaneous counterpart LLMs (sellers + trade targets) |\n",
311
  "| **Long-Horizon** | Route-confirmation arc spans multiple rounds with full state tracking |\n",
312
+ "| **Self-Improvement** | Curriculum RL: Phase 1 + Phase 2 + Unified, three reward signals logged |\n",
313
  "| **Wild Card** | Autonomous capital deployment via confirmed route arbitrage ($20 → execute) |\n",
314
  "| **Halluminate $10k** | Agent managing multiple actors to discover and achieve the task |\n",
315
  "| **Fleet AI $10k** | Bluff detection layer as oversight agent scoring counterpart behavior |\n",
316
  "\n",
317
+ "**Submission links:** Repo (GitHub), Demo (HuggingFace Spaces), Video (1-min YouTube), Submit at cerebralvalley.ai — Sunday 1:00 PM"
 
 
 
 
318
  ]
319
  }
320
+ ],
321
+ "metadata": {
322
+ "colab": {
323
+ "provenance": [],
324
+ "gpuType": "T4"
325
+ },
326
+ "kernelspec": {
327
+ "display_name": "Python 3",
328
+ "language": "python",
329
+ "name": "python3"
330
+ },
331
+ "language_info": {
332
+ "name": "python",
333
+ "version": "3.10.0"
334
+ }
335
+ },
336
+ "nbformat": 4,
337
+ "nbformat_minor": 4
338
+ }
training/bluff_training.log ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ DistilBertModel LOAD REPORT from: distilbert-base-uncased
3
+ Key | Status | |
4
+ ------------------------+------------+--+-
5
+ vocab_transform.weight | UNEXPECTED | |
6
+ vocab_projector.bias | UNEXPECTED | |
7
+ vocab_transform.bias | UNEXPECTED | |
8
+ vocab_layer_norm.bias | UNEXPECTED | |
9
+ vocab_layer_norm.weight | UNEXPECTED | |
10
+
11
+ Notes:
12
+ - UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
13
+ Epoch 1/3 Val accuracy: 0.9999 Val F1: 0.9981
14
+ Epoch 2/3 Val accuracy: 1.0000 Val F1: 1.0000
15
+ Epoch 3/3 Val accuracy: 1.0000 Val F1: 1.0000
16
+ Saved model to /home/rayyan/Desktop/Play-gent/training/checkpoints/bluff_classifier.pt, tokenizer to /home/rayyan/Desktop/Play-gent/training/checkpoints/bluff_classifier_tokenizer
training/checkpoints/bluff_classifier_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/bluff_classifier_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "is_local": false,
6
+ "mask_token": "[MASK]",
7
+ "model_max_length": 512,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "strip_accents": null,
11
+ "tokenize_chinese_chars": true,
12
+ "tokenizer_class": "BertTokenizer",
13
+ "unk_token": "[UNK]"
14
+ }
training/checkpoints/phase2_final/README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ model_name: phase2_final
4
+ tags:
5
+ - generated_from_trainer
6
+ - trl
7
+ - grpo
8
+ licence: license
9
+ ---
10
+
11
+ # Model Card for phase2_final
12
+
13
+ This model is a fine-tuned version of [None](https://huggingface.co/None).
14
+ It has been trained using [TRL](https://github.com/huggingface/trl).
15
+
16
+ ## Quick start
17
+
18
+ ```python
19
+ from transformers import pipeline
20
+
21
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
22
+ generator = pipeline("text-generation", model="None", device="cuda")
23
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
24
+ print(output["generated_text"])
25
+ ```
26
+
27
+ ## Training procedure
28
+
29
+
30
+
31
+
32
+
33
+ This model was trained with GRPO, a method introduced in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
34
+
35
+ ### Framework versions
36
+
37
+ - TRL: 0.29.0
38
+ - Transformers: 5.3.0
39
+ - Pytorch: 2.12.0.dev20260307+cu128
40
+ - Datasets: 4.6.1
41
+ - Tokenizers: 0.22.2
42
+
43
+ ## Citations
44
+
45
+ Cite GRPO as:
46
+
47
+ ```bibtex
48
+ @article{shao2024deepseekmath,
49
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
50
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
51
+ year = 2024,
52
+ eprint = {arXiv:2402.03300},
53
+ }
54
+
55
+ ```
56
+
57
+ Cite TRL as:
58
+
59
+ ```bibtex
60
+ @software{vonwerra2020trl,
61
+ title = {{TRL: Transformers Reinforcement Learning}},
62
+ author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouédec, Quentin},
63
+ license = {Apache-2.0},
64
+ url = {https://github.com/huggingface/trl},
65
+ year = {2020}
66
+ }
67
+ ```
training/checkpoints/phase2_final/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/phase2_final/checkpoint-100/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/phase2_final/checkpoint-100/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/phase2_final/checkpoint-100/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/phase2_final/checkpoint-100/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/phase2_final/checkpoint-100/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/phase2_final/checkpoint-100/trainer_state.json ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.1,
6
+ "eval_steps": 500,
7
+ "global_step": 100,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "clip_ratio/high_max": 0.0,
14
+ "clip_ratio/high_mean": 0.0,
15
+ "clip_ratio/low_mean": 0.0,
16
+ "clip_ratio/low_min": 0.0,
17
+ "clip_ratio/region_mean": 0.0,
18
+ "completions/clipped_ratio": 1.0,
19
+ "completions/max_length": 100.0,
20
+ "completions/max_terminated_length": 0.0,
21
+ "completions/mean_length": 100.0,
22
+ "completions/mean_terminated_length": 0.0,
23
+ "completions/min_length": 100.0,
24
+ "completions/min_terminated_length": 0.0,
25
+ "entropy": 0.7924717187881469,
26
+ "epoch": 0.01,
27
+ "frac_reward_zero_std": 0.45,
28
+ "grad_norm": 1.5374246835708618,
29
+ "learning_rate": 4.775e-06,
30
+ "loss": 1.4901161193847657e-09,
31
+ "num_tokens": 35664.0,
32
+ "reward": 0.11875000391155481,
33
+ "reward_std": 0.09771842509508133,
34
+ "rewards/compute_reward/mean": 0.11875000391155481,
35
+ "rewards/compute_reward/std": 0.09771843403577804,
36
+ "step": 10,
37
+ "step_time": 15.109664801302278
38
+ },
39
+ {
40
+ "clip_ratio/high_max": 0.0,
41
+ "clip_ratio/high_mean": 0.0,
42
+ "clip_ratio/low_mean": 0.0,
43
+ "clip_ratio/low_min": 0.0,
44
+ "clip_ratio/region_mean": 0.0,
45
+ "completions/clipped_ratio": 1.0,
46
+ "completions/max_length": 100.0,
47
+ "completions/max_terminated_length": 0.0,
48
+ "completions/mean_length": 100.0,
49
+ "completions/mean_terminated_length": 0.0,
50
+ "completions/min_length": 100.0,
51
+ "completions/min_terminated_length": 0.0,
52
+ "entropy": 0.8351163290441036,
53
+ "epoch": 0.02,
54
+ "frac_reward_zero_std": 0.65,
55
+ "grad_norm": 0.0,
56
+ "learning_rate": 4.525000000000001e-06,
57
+ "loss": 2.6822090148925782e-08,
58
+ "num_tokens": 70060.0,
59
+ "reward": 0.15750000774860382,
60
+ "reward_std": 0.04840061739087105,
61
+ "rewards/compute_reward/mean": 0.15750000774860382,
62
+ "rewards/compute_reward/std": 0.04840061739087105,
63
+ "step": 20,
64
+ "step_time": 14.928892047195404
65
+ },
66
+ {
67
+ "clip_ratio/high_max": 0.0,
68
+ "clip_ratio/high_mean": 0.0,
69
+ "clip_ratio/low_mean": 0.0,
70
+ "clip_ratio/low_min": 0.0,
71
+ "clip_ratio/region_mean": 0.0,
72
+ "completions/clipped_ratio": 1.0,
73
+ "completions/max_length": 100.0,
74
+ "completions/max_terminated_length": 0.0,
75
+ "completions/mean_length": 100.0,
76
+ "completions/mean_terminated_length": 0.0,
77
+ "completions/min_length": 100.0,
78
+ "completions/min_terminated_length": 0.0,
79
+ "entropy": 0.41533662043511865,
80
+ "epoch": 0.03,
81
+ "frac_reward_zero_std": 0.8,
82
+ "grad_norm": 0.0,
83
+ "learning_rate": 4.2750000000000006e-06,
84
+ "loss": 1.4901161193847657e-09,
85
+ "num_tokens": 105588.0,
86
+ "reward": 0.06375000178813935,
87
+ "reward_std": 0.04330107718706131,
88
+ "rewards/compute_reward/mean": 0.06375000178813935,
89
+ "rewards/compute_reward/std": 0.04330108165740967,
90
+ "step": 30,
91
+ "step_time": 15.109792457801813
92
+ },
93
+ {
94
+ "clip_ratio/high_max": 0.0,
95
+ "clip_ratio/high_mean": 0.0,
96
+ "clip_ratio/low_mean": 0.0,
97
+ "clip_ratio/low_min": 0.0,
98
+ "clip_ratio/region_mean": 0.0,
99
+ "completions/clipped_ratio": 1.0,
100
+ "completions/max_length": 100.0,
101
+ "completions/max_terminated_length": 0.0,
102
+ "completions/mean_length": 100.0,
103
+ "completions/mean_terminated_length": 0.0,
104
+ "completions/min_length": 100.0,
105
+ "completions/min_terminated_length": 0.0,
106
+ "entropy": 1.246315559744835,
107
+ "epoch": 0.04,
108
+ "frac_reward_zero_std": 1.0,
109
+ "grad_norm": 0.0,
110
+ "learning_rate": 4.0250000000000004e-06,
111
+ "loss": 0.0,
112
+ "num_tokens": 141264.0,
113
+ "reward": 0.30000001192092896,
114
+ "reward_std": 0.0,
115
+ "rewards/compute_reward/mean": 0.30000001192092896,
116
+ "rewards/compute_reward/std": 0.0,
117
+ "step": 40,
118
+ "step_time": 15.195196880902222
119
+ },
120
+ {
121
+ "clip_ratio/high_max": 0.0,
122
+ "clip_ratio/high_mean": 0.0,
123
+ "clip_ratio/low_mean": 0.0,
124
+ "clip_ratio/low_min": 0.0,
125
+ "clip_ratio/region_mean": 0.0,
126
+ "completions/clipped_ratio": 1.0,
127
+ "completions/max_length": 100.0,
128
+ "completions/max_terminated_length": 0.0,
129
+ "completions/mean_length": 100.0,
130
+ "completions/mean_terminated_length": 0.0,
131
+ "completions/min_length": 100.0,
132
+ "completions/min_terminated_length": 0.0,
133
+ "entropy": 0.7081560462713241,
134
+ "epoch": 0.05,
135
+ "frac_reward_zero_std": 1.0,
136
+ "grad_norm": 0.0,
137
+ "learning_rate": 3.7750000000000003e-06,
138
+ "loss": 0.0,
139
+ "num_tokens": 176780.0,
140
+ "reward": 0.30000001192092896,
141
+ "reward_std": 0.0,
142
+ "rewards/compute_reward/mean": 0.30000001192092896,
143
+ "rewards/compute_reward/std": 0.0,
144
+ "step": 50,
145
+ "step_time": 15.140776808797819
146
+ },
147
+ {
148
+ "clip_ratio/high_max": 0.0,
149
+ "clip_ratio/high_mean": 0.0,
150
+ "clip_ratio/low_mean": 0.0,
151
+ "clip_ratio/low_min": 0.0,
152
+ "clip_ratio/region_mean": 0.0,
153
+ "completions/clipped_ratio": 1.0,
154
+ "completions/max_length": 100.0,
155
+ "completions/max_terminated_length": 0.0,
156
+ "completions/mean_length": 100.0,
157
+ "completions/mean_terminated_length": 0.0,
158
+ "completions/min_length": 100.0,
159
+ "completions/min_terminated_length": 0.0,
160
+ "entropy": 0.727844113111496,
161
+ "epoch": 0.06,
162
+ "frac_reward_zero_std": 1.0,
163
+ "grad_norm": 0.0,
164
+ "learning_rate": 3.525e-06,
165
+ "loss": 0.0,
166
+ "num_tokens": 212628.0,
167
+ "reward": 0.30000001192092896,
168
+ "reward_std": 0.0,
169
+ "rewards/compute_reward/mean": 0.30000001192092896,
170
+ "rewards/compute_reward/std": 0.0,
171
+ "step": 60,
172
+ "step_time": 15.286061269601486
173
+ },
174
+ {
175
+ "clip_ratio/high_max": 0.0,
176
+ "clip_ratio/high_mean": 0.0,
177
+ "clip_ratio/low_mean": 0.0,
178
+ "clip_ratio/low_min": 0.0,
179
+ "clip_ratio/region_mean": 0.0,
180
+ "completions/clipped_ratio": 1.0,
181
+ "completions/max_length": 100.0,
182
+ "completions/max_terminated_length": 0.0,
183
+ "completions/mean_length": 100.0,
184
+ "completions/mean_terminated_length": 0.0,
185
+ "completions/min_length": 100.0,
186
+ "completions/min_terminated_length": 0.0,
187
+ "entropy": 0.7312307402491569,
188
+ "epoch": 0.07,
189
+ "frac_reward_zero_std": 1.0,
190
+ "grad_norm": 0.0,
191
+ "learning_rate": 3.2750000000000004e-06,
192
+ "loss": 0.0,
193
+ "num_tokens": 248212.0,
194
+ "reward": 0.30000001192092896,
195
+ "reward_std": 0.0,
196
+ "rewards/compute_reward/mean": 0.30000001192092896,
197
+ "rewards/compute_reward/std": 0.0,
198
+ "step": 70,
199
+ "step_time": 15.278303197700733
200
+ },
201
+ {
202
+ "clip_ratio/high_max": 0.0,
203
+ "clip_ratio/high_mean": 0.0,
204
+ "clip_ratio/low_mean": 0.0,
205
+ "clip_ratio/low_min": 0.0,
206
+ "clip_ratio/region_mean": 0.0,
207
+ "completions/clipped_ratio": 1.0,
208
+ "completions/max_length": 100.0,
209
+ "completions/max_terminated_length": 0.0,
210
+ "completions/mean_length": 100.0,
211
+ "completions/mean_terminated_length": 0.0,
212
+ "completions/min_length": 100.0,
213
+ "completions/min_terminated_length": 0.0,
214
+ "entropy": 0.7322262570261955,
215
+ "epoch": 0.08,
216
+ "frac_reward_zero_std": 1.0,
217
+ "grad_norm": 0.0,
218
+ "learning_rate": 3.0250000000000003e-06,
219
+ "loss": 0.0,
220
+ "num_tokens": 283644.0,
221
+ "reward": 0.30000001192092896,
222
+ "reward_std": 0.0,
223
+ "rewards/compute_reward/mean": 0.30000001192092896,
224
+ "rewards/compute_reward/std": 0.0,
225
+ "step": 80,
226
+ "step_time": 15.146252356799959
227
+ },
228
+ {
229
+ "clip_ratio/high_max": 0.0,
230
+ "clip_ratio/high_mean": 0.0,
231
+ "clip_ratio/low_mean": 0.0,
232
+ "clip_ratio/low_min": 0.0,
233
+ "clip_ratio/region_mean": 0.0,
234
+ "completions/clipped_ratio": 1.0,
235
+ "completions/max_length": 100.0,
236
+ "completions/max_terminated_length": 0.0,
237
+ "completions/mean_length": 100.0,
238
+ "completions/mean_terminated_length": 0.0,
239
+ "completions/min_length": 100.0,
240
+ "completions/min_terminated_length": 0.0,
241
+ "entropy": 0.7361132100224494,
242
+ "epoch": 0.09,
243
+ "frac_reward_zero_std": 1.0,
244
+ "grad_norm": 0.0,
245
+ "learning_rate": 2.7750000000000005e-06,
246
+ "loss": 0.0,
247
+ "num_tokens": 318532.0,
248
+ "reward": 0.30000001192092896,
249
+ "reward_std": 0.0,
250
+ "rewards/compute_reward/mean": 0.30000001192092896,
251
+ "rewards/compute_reward/std": 0.0,
252
+ "step": 90,
253
+ "step_time": 15.026733554197563
254
+ },
255
+ {
256
+ "clip_ratio/high_max": 0.0,
257
+ "clip_ratio/high_mean": 0.0,
258
+ "clip_ratio/low_mean": 0.0,
259
+ "clip_ratio/low_min": 0.0,
260
+ "clip_ratio/region_mean": 0.0,
261
+ "completions/clipped_ratio": 1.0,
262
+ "completions/max_length": 100.0,
263
+ "completions/max_terminated_length": 0.0,
264
+ "completions/mean_length": 100.0,
265
+ "completions/mean_terminated_length": 0.0,
266
+ "completions/min_length": 100.0,
267
+ "completions/min_terminated_length": 0.0,
268
+ "entropy": 0.7636664807796478,
269
+ "epoch": 0.1,
270
+ "frac_reward_zero_std": 1.0,
271
+ "grad_norm": 0.0,
272
+ "learning_rate": 2.5250000000000004e-06,
273
+ "loss": 0.0,
274
+ "num_tokens": 355352.0,
275
+ "reward": 0.30000001192092896,
276
+ "reward_std": 0.0,
277
+ "rewards/compute_reward/mean": 0.30000001192092896,
278
+ "rewards/compute_reward/std": 0.0,
279
+ "step": 100,
280
+ "step_time": 15.381215008600702
281
+ }
282
+ ],
283
+ "logging_steps": 10,
284
+ "max_steps": 200,
285
+ "num_input_tokens_seen": 355352,
286
+ "num_train_epochs": 1,
287
+ "save_steps": 100,
288
+ "stateful_callbacks": {
289
+ "TrainerControl": {
290
+ "args": {
291
+ "should_epoch_stop": false,
292
+ "should_evaluate": false,
293
+ "should_log": false,
294
+ "should_save": true,
295
+ "should_training_stop": false
296
+ },
297
+ "attributes": {}
298
+ }
299
+ },
300
+ "total_flos": 0.0,
301
+ "train_batch_size": 2,
302
+ "trial_name": null,
303
+ "trial_params": null
304
+ }
training/checkpoints/phase2_final/checkpoint-200/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/phase2_final/checkpoint-200/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/phase2_final/checkpoint-200/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/phase2_final/checkpoint-200/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/phase2_final/checkpoint-200/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/phase2_final/checkpoint-200/trainer_state.json ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.2,
6
+ "eval_steps": 500,
7
+ "global_step": 200,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "clip_ratio/high_max": 0.0,
14
+ "clip_ratio/high_mean": 0.0,
15
+ "clip_ratio/low_mean": 0.0,
16
+ "clip_ratio/low_min": 0.0,
17
+ "clip_ratio/region_mean": 0.0,
18
+ "completions/clipped_ratio": 1.0,
19
+ "completions/max_length": 100.0,
20
+ "completions/max_terminated_length": 0.0,
21
+ "completions/mean_length": 100.0,
22
+ "completions/mean_terminated_length": 0.0,
23
+ "completions/min_length": 100.0,
24
+ "completions/min_terminated_length": 0.0,
25
+ "entropy": 0.7924717187881469,
26
+ "epoch": 0.01,
27
+ "frac_reward_zero_std": 0.45,
28
+ "grad_norm": 1.5374246835708618,
29
+ "learning_rate": 4.775e-06,
30
+ "loss": 1.4901161193847657e-09,
31
+ "num_tokens": 35664.0,
32
+ "reward": 0.11875000391155481,
33
+ "reward_std": 0.09771842509508133,
34
+ "rewards/compute_reward/mean": 0.11875000391155481,
35
+ "rewards/compute_reward/std": 0.09771843403577804,
36
+ "step": 10,
37
+ "step_time": 15.109664801302278
38
+ },
39
+ {
40
+ "clip_ratio/high_max": 0.0,
41
+ "clip_ratio/high_mean": 0.0,
42
+ "clip_ratio/low_mean": 0.0,
43
+ "clip_ratio/low_min": 0.0,
44
+ "clip_ratio/region_mean": 0.0,
45
+ "completions/clipped_ratio": 1.0,
46
+ "completions/max_length": 100.0,
47
+ "completions/max_terminated_length": 0.0,
48
+ "completions/mean_length": 100.0,
49
+ "completions/mean_terminated_length": 0.0,
50
+ "completions/min_length": 100.0,
51
+ "completions/min_terminated_length": 0.0,
52
+ "entropy": 0.8351163290441036,
53
+ "epoch": 0.02,
54
+ "frac_reward_zero_std": 0.65,
55
+ "grad_norm": 0.0,
56
+ "learning_rate": 4.525000000000001e-06,
57
+ "loss": 2.6822090148925782e-08,
58
+ "num_tokens": 70060.0,
59
+ "reward": 0.15750000774860382,
60
+ "reward_std": 0.04840061739087105,
61
+ "rewards/compute_reward/mean": 0.15750000774860382,
62
+ "rewards/compute_reward/std": 0.04840061739087105,
63
+ "step": 20,
64
+ "step_time": 14.928892047195404
65
+ },
66
+ {
67
+ "clip_ratio/high_max": 0.0,
68
+ "clip_ratio/high_mean": 0.0,
69
+ "clip_ratio/low_mean": 0.0,
70
+ "clip_ratio/low_min": 0.0,
71
+ "clip_ratio/region_mean": 0.0,
72
+ "completions/clipped_ratio": 1.0,
73
+ "completions/max_length": 100.0,
74
+ "completions/max_terminated_length": 0.0,
75
+ "completions/mean_length": 100.0,
76
+ "completions/mean_terminated_length": 0.0,
77
+ "completions/min_length": 100.0,
78
+ "completions/min_terminated_length": 0.0,
79
+ "entropy": 0.41533662043511865,
80
+ "epoch": 0.03,
81
+ "frac_reward_zero_std": 0.8,
82
+ "grad_norm": 0.0,
83
+ "learning_rate": 4.2750000000000006e-06,
84
+ "loss": 1.4901161193847657e-09,
85
+ "num_tokens": 105588.0,
86
+ "reward": 0.06375000178813935,
87
+ "reward_std": 0.04330107718706131,
88
+ "rewards/compute_reward/mean": 0.06375000178813935,
89
+ "rewards/compute_reward/std": 0.04330108165740967,
90
+ "step": 30,
91
+ "step_time": 15.109792457801813
92
+ },
93
+ {
94
+ "clip_ratio/high_max": 0.0,
95
+ "clip_ratio/high_mean": 0.0,
96
+ "clip_ratio/low_mean": 0.0,
97
+ "clip_ratio/low_min": 0.0,
98
+ "clip_ratio/region_mean": 0.0,
99
+ "completions/clipped_ratio": 1.0,
100
+ "completions/max_length": 100.0,
101
+ "completions/max_terminated_length": 0.0,
102
+ "completions/mean_length": 100.0,
103
+ "completions/mean_terminated_length": 0.0,
104
+ "completions/min_length": 100.0,
105
+ "completions/min_terminated_length": 0.0,
106
+ "entropy": 1.246315559744835,
107
+ "epoch": 0.04,
108
+ "frac_reward_zero_std": 1.0,
109
+ "grad_norm": 0.0,
110
+ "learning_rate": 4.0250000000000004e-06,
111
+ "loss": 0.0,
112
+ "num_tokens": 141264.0,
113
+ "reward": 0.30000001192092896,
114
+ "reward_std": 0.0,
115
+ "rewards/compute_reward/mean": 0.30000001192092896,
116
+ "rewards/compute_reward/std": 0.0,
117
+ "step": 40,
118
+ "step_time": 15.195196880902222
119
+ },
120
+ {
121
+ "clip_ratio/high_max": 0.0,
122
+ "clip_ratio/high_mean": 0.0,
123
+ "clip_ratio/low_mean": 0.0,
124
+ "clip_ratio/low_min": 0.0,
125
+ "clip_ratio/region_mean": 0.0,
126
+ "completions/clipped_ratio": 1.0,
127
+ "completions/max_length": 100.0,
128
+ "completions/max_terminated_length": 0.0,
129
+ "completions/mean_length": 100.0,
130
+ "completions/mean_terminated_length": 0.0,
131
+ "completions/min_length": 100.0,
132
+ "completions/min_terminated_length": 0.0,
133
+ "entropy": 0.7081560462713241,
134
+ "epoch": 0.05,
135
+ "frac_reward_zero_std": 1.0,
136
+ "grad_norm": 0.0,
137
+ "learning_rate": 3.7750000000000003e-06,
138
+ "loss": 0.0,
139
+ "num_tokens": 176780.0,
140
+ "reward": 0.30000001192092896,
141
+ "reward_std": 0.0,
142
+ "rewards/compute_reward/mean": 0.30000001192092896,
143
+ "rewards/compute_reward/std": 0.0,
144
+ "step": 50,
145
+ "step_time": 15.140776808797819
146
+ },
147
+ {
148
+ "clip_ratio/high_max": 0.0,
149
+ "clip_ratio/high_mean": 0.0,
150
+ "clip_ratio/low_mean": 0.0,
151
+ "clip_ratio/low_min": 0.0,
152
+ "clip_ratio/region_mean": 0.0,
153
+ "completions/clipped_ratio": 1.0,
154
+ "completions/max_length": 100.0,
155
+ "completions/max_terminated_length": 0.0,
156
+ "completions/mean_length": 100.0,
157
+ "completions/mean_terminated_length": 0.0,
158
+ "completions/min_length": 100.0,
159
+ "completions/min_terminated_length": 0.0,
160
+ "entropy": 0.727844113111496,
161
+ "epoch": 0.06,
162
+ "frac_reward_zero_std": 1.0,
163
+ "grad_norm": 0.0,
164
+ "learning_rate": 3.525e-06,
165
+ "loss": 0.0,
166
+ "num_tokens": 212628.0,
167
+ "reward": 0.30000001192092896,
168
+ "reward_std": 0.0,
169
+ "rewards/compute_reward/mean": 0.30000001192092896,
170
+ "rewards/compute_reward/std": 0.0,
171
+ "step": 60,
172
+ "step_time": 15.286061269601486
173
+ },
174
+ {
175
+ "clip_ratio/high_max": 0.0,
176
+ "clip_ratio/high_mean": 0.0,
177
+ "clip_ratio/low_mean": 0.0,
178
+ "clip_ratio/low_min": 0.0,
179
+ "clip_ratio/region_mean": 0.0,
180
+ "completions/clipped_ratio": 1.0,
181
+ "completions/max_length": 100.0,
182
+ "completions/max_terminated_length": 0.0,
183
+ "completions/mean_length": 100.0,
184
+ "completions/mean_terminated_length": 0.0,
185
+ "completions/min_length": 100.0,
186
+ "completions/min_terminated_length": 0.0,
187
+ "entropy": 0.7312307402491569,
188
+ "epoch": 0.07,
189
+ "frac_reward_zero_std": 1.0,
190
+ "grad_norm": 0.0,
191
+ "learning_rate": 3.2750000000000004e-06,
192
+ "loss": 0.0,
193
+ "num_tokens": 248212.0,
194
+ "reward": 0.30000001192092896,
195
+ "reward_std": 0.0,
196
+ "rewards/compute_reward/mean": 0.30000001192092896,
197
+ "rewards/compute_reward/std": 0.0,
198
+ "step": 70,
199
+ "step_time": 15.278303197700733
200
+ },
201
+ {
202
+ "clip_ratio/high_max": 0.0,
203
+ "clip_ratio/high_mean": 0.0,
204
+ "clip_ratio/low_mean": 0.0,
205
+ "clip_ratio/low_min": 0.0,
206
+ "clip_ratio/region_mean": 0.0,
207
+ "completions/clipped_ratio": 1.0,
208
+ "completions/max_length": 100.0,
209
+ "completions/max_terminated_length": 0.0,
210
+ "completions/mean_length": 100.0,
211
+ "completions/mean_terminated_length": 0.0,
212
+ "completions/min_length": 100.0,
213
+ "completions/min_terminated_length": 0.0,
214
+ "entropy": 0.7322262570261955,
215
+ "epoch": 0.08,
216
+ "frac_reward_zero_std": 1.0,
217
+ "grad_norm": 0.0,
218
+ "learning_rate": 3.0250000000000003e-06,
219
+ "loss": 0.0,
220
+ "num_tokens": 283644.0,
221
+ "reward": 0.30000001192092896,
222
+ "reward_std": 0.0,
223
+ "rewards/compute_reward/mean": 0.30000001192092896,
224
+ "rewards/compute_reward/std": 0.0,
225
+ "step": 80,
226
+ "step_time": 15.146252356799959
227
+ },
228
+ {
229
+ "clip_ratio/high_max": 0.0,
230
+ "clip_ratio/high_mean": 0.0,
231
+ "clip_ratio/low_mean": 0.0,
232
+ "clip_ratio/low_min": 0.0,
233
+ "clip_ratio/region_mean": 0.0,
234
+ "completions/clipped_ratio": 1.0,
235
+ "completions/max_length": 100.0,
236
+ "completions/max_terminated_length": 0.0,
237
+ "completions/mean_length": 100.0,
238
+ "completions/mean_terminated_length": 0.0,
239
+ "completions/min_length": 100.0,
240
+ "completions/min_terminated_length": 0.0,
241
+ "entropy": 0.7361132100224494,
242
+ "epoch": 0.09,
243
+ "frac_reward_zero_std": 1.0,
244
+ "grad_norm": 0.0,
245
+ "learning_rate": 2.7750000000000005e-06,
246
+ "loss": 0.0,
247
+ "num_tokens": 318532.0,
248
+ "reward": 0.30000001192092896,
249
+ "reward_std": 0.0,
250
+ "rewards/compute_reward/mean": 0.30000001192092896,
251
+ "rewards/compute_reward/std": 0.0,
252
+ "step": 90,
253
+ "step_time": 15.026733554197563
254
+ },
255
+ {
256
+ "clip_ratio/high_max": 0.0,
257
+ "clip_ratio/high_mean": 0.0,
258
+ "clip_ratio/low_mean": 0.0,
259
+ "clip_ratio/low_min": 0.0,
260
+ "clip_ratio/region_mean": 0.0,
261
+ "completions/clipped_ratio": 1.0,
262
+ "completions/max_length": 100.0,
263
+ "completions/max_terminated_length": 0.0,
264
+ "completions/mean_length": 100.0,
265
+ "completions/mean_terminated_length": 0.0,
266
+ "completions/min_length": 100.0,
267
+ "completions/min_terminated_length": 0.0,
268
+ "entropy": 0.7636664807796478,
269
+ "epoch": 0.1,
270
+ "frac_reward_zero_std": 1.0,
271
+ "grad_norm": 0.0,
272
+ "learning_rate": 2.5250000000000004e-06,
273
+ "loss": 0.0,
274
+ "num_tokens": 355352.0,
275
+ "reward": 0.30000001192092896,
276
+ "reward_std": 0.0,
277
+ "rewards/compute_reward/mean": 0.30000001192092896,
278
+ "rewards/compute_reward/std": 0.0,
279
+ "step": 100,
280
+ "step_time": 15.381215008600702
281
+ },
282
+ {
283
+ "clip_ratio/high_max": 0.0,
284
+ "clip_ratio/high_mean": 0.0,
285
+ "clip_ratio/low_mean": 0.0,
286
+ "clip_ratio/low_min": 0.0,
287
+ "clip_ratio/region_mean": 0.0,
288
+ "completions/clipped_ratio": 1.0,
289
+ "completions/max_length": 100.0,
290
+ "completions/max_terminated_length": 0.0,
291
+ "completions/mean_length": 100.0,
292
+ "completions/mean_terminated_length": 0.0,
293
+ "completions/min_length": 100.0,
294
+ "completions/min_terminated_length": 0.0,
295
+ "entropy": 0.7429351836442948,
296
+ "epoch": 0.11,
297
+ "frac_reward_zero_std": 1.0,
298
+ "grad_norm": 0.0,
299
+ "learning_rate": 2.2750000000000002e-06,
300
+ "loss": 0.0,
301
+ "num_tokens": 389508.0,
302
+ "reward": 0.30000001192092896,
303
+ "reward_std": 0.0,
304
+ "rewards/compute_reward/mean": 0.30000001192092896,
305
+ "rewards/compute_reward/std": 0.0,
306
+ "step": 110,
307
+ "step_time": 15.039604106301704
308
+ },
309
+ {
310
+ "clip_ratio/high_max": 0.0,
311
+ "clip_ratio/high_mean": 0.0,
312
+ "clip_ratio/low_mean": 0.0,
313
+ "clip_ratio/low_min": 0.0,
314
+ "clip_ratio/region_mean": 0.0,
315
+ "completions/clipped_ratio": 1.0,
316
+ "completions/max_length": 100.0,
317
+ "completions/max_terminated_length": 0.0,
318
+ "completions/mean_length": 100.0,
319
+ "completions/mean_terminated_length": 0.0,
320
+ "completions/min_length": 100.0,
321
+ "completions/min_terminated_length": 0.0,
322
+ "entropy": 0.7703481003642082,
323
+ "epoch": 0.12,
324
+ "frac_reward_zero_std": 1.0,
325
+ "grad_norm": 0.0,
326
+ "learning_rate": 2.025e-06,
327
+ "loss": 0.0,
328
+ "num_tokens": 426240.0,
329
+ "reward": 0.30000001192092896,
330
+ "reward_std": 0.0,
331
+ "rewards/compute_reward/mean": 0.30000001192092896,
332
+ "rewards/compute_reward/std": 0.0,
333
+ "step": 120,
334
+ "step_time": 15.29271342299835
335
+ },
336
+ {
337
+ "clip_ratio/high_max": 0.0,
338
+ "clip_ratio/high_mean": 0.0,
339
+ "clip_ratio/low_mean": 0.0,
340
+ "clip_ratio/low_min": 0.0,
341
+ "clip_ratio/region_mean": 0.0,
342
+ "completions/clipped_ratio": 1.0,
343
+ "completions/max_length": 100.0,
344
+ "completions/max_terminated_length": 0.0,
345
+ "completions/mean_length": 100.0,
346
+ "completions/mean_terminated_length": 0.0,
347
+ "completions/min_length": 100.0,
348
+ "completions/min_terminated_length": 0.0,
349
+ "entropy": 0.7375139251351357,
350
+ "epoch": 0.13,
351
+ "frac_reward_zero_std": 1.0,
352
+ "grad_norm": 0.0,
353
+ "learning_rate": 1.7750000000000002e-06,
354
+ "loss": 0.0,
355
+ "num_tokens": 462400.0,
356
+ "reward": 0.30000001192092896,
357
+ "reward_std": 0.0,
358
+ "rewards/compute_reward/mean": 0.30000001192092896,
359
+ "rewards/compute_reward/std": 0.0,
360
+ "step": 130,
361
+ "step_time": 15.20639470120077
362
+ },
363
+ {
364
+ "clip_ratio/high_max": 0.0,
365
+ "clip_ratio/high_mean": 0.0,
366
+ "clip_ratio/low_mean": 0.0,
367
+ "clip_ratio/low_min": 0.0,
368
+ "clip_ratio/region_mean": 0.0,
369
+ "completions/clipped_ratio": 1.0,
370
+ "completions/max_length": 100.0,
371
+ "completions/max_terminated_length": 0.0,
372
+ "completions/mean_length": 100.0,
373
+ "completions/mean_terminated_length": 0.0,
374
+ "completions/min_length": 100.0,
375
+ "completions/min_terminated_length": 0.0,
376
+ "entropy": 0.8568216070532799,
377
+ "epoch": 0.14,
378
+ "frac_reward_zero_std": 0.9,
379
+ "grad_norm": 0.0,
380
+ "learning_rate": 1.525e-06,
381
+ "loss": 1.7881393432617187e-08,
382
+ "num_tokens": 498020.0,
383
+ "reward": 0.30500001311302183,
384
+ "reward_std": 0.01414213478565216,
385
+ "rewards/compute_reward/mean": 0.30500001311302183,
386
+ "rewards/compute_reward/std": 0.01414213478565216,
387
+ "step": 140,
388
+ "step_time": 15.200056954801402
389
+ },
390
+ {
391
+ "clip_ratio/high_max": 0.0,
392
+ "clip_ratio/high_mean": 0.0,
393
+ "clip_ratio/low_mean": 0.0,
394
+ "clip_ratio/low_min": 0.0,
395
+ "clip_ratio/region_mean": 0.0,
396
+ "completions/clipped_ratio": 1.0,
397
+ "completions/max_length": 100.0,
398
+ "completions/max_terminated_length": 0.0,
399
+ "completions/mean_length": 100.0,
400
+ "completions/mean_terminated_length": 0.0,
401
+ "completions/min_length": 100.0,
402
+ "completions/min_terminated_length": 0.0,
403
+ "entropy": 1.4760520339012146,
404
+ "epoch": 0.15,
405
+ "frac_reward_zero_std": 0.95,
406
+ "grad_norm": 0.0,
407
+ "learning_rate": 1.275e-06,
408
+ "loss": 8.940696716308593e-09,
409
+ "num_tokens": 532668.0,
410
+ "reward": 0.3025000125169754,
411
+ "reward_std": 0.00707106739282608,
412
+ "rewards/compute_reward/mean": 0.3025000125169754,
413
+ "rewards/compute_reward/std": 0.00707106739282608,
414
+ "step": 150,
415
+ "step_time": 14.748404727898015
416
+ },
417
+ {
418
+ "clip_ratio/high_max": 0.0,
419
+ "clip_ratio/high_mean": 0.0,
420
+ "clip_ratio/low_mean": 0.0,
421
+ "clip_ratio/low_min": 0.0,
422
+ "clip_ratio/region_mean": 0.0,
423
+ "completions/clipped_ratio": 1.0,
424
+ "completions/max_length": 100.0,
425
+ "completions/max_terminated_length": 0.0,
426
+ "completions/mean_length": 100.0,
427
+ "completions/mean_terminated_length": 0.0,
428
+ "completions/min_length": 100.0,
429
+ "completions/min_terminated_length": 0.0,
430
+ "entropy": 1.7379814833402634,
431
+ "epoch": 0.16,
432
+ "frac_reward_zero_std": 0.9,
433
+ "grad_norm": 0.0,
434
+ "learning_rate": 1.025e-06,
435
+ "loss": 1.564621925354004e-08,
436
+ "num_tokens": 567544.0,
437
+ "reward": 0.3037500113248825,
438
+ "reward_std": 0.01060660146176815,
439
+ "rewards/compute_reward/mean": 0.3037500113248825,
440
+ "rewards/compute_reward/std": 0.01060660108923912,
441
+ "step": 160,
442
+ "step_time": 15.037257523898734
443
+ },
444
+ {
445
+ "clip_ratio/high_max": 0.0,
446
+ "clip_ratio/high_mean": 0.0,
447
+ "clip_ratio/low_mean": 0.0,
448
+ "clip_ratio/low_min": 0.0,
449
+ "clip_ratio/region_mean": 0.0,
450
+ "completions/clipped_ratio": 1.0,
451
+ "completions/max_length": 100.0,
452
+ "completions/max_terminated_length": 0.0,
453
+ "completions/mean_length": 100.0,
454
+ "completions/mean_terminated_length": 0.0,
455
+ "completions/min_length": 100.0,
456
+ "completions/min_terminated_length": 0.0,
457
+ "entropy": 1.5534777998924256,
458
+ "epoch": 0.17,
459
+ "frac_reward_zero_std": 0.8,
460
+ "grad_norm": 0.0,
461
+ "learning_rate": 7.750000000000001e-07,
462
+ "loss": 1.7881393432617187e-08,
463
+ "num_tokens": 604400.0,
464
+ "reward": 0.31500001549720763,
465
+ "reward_std": 0.032658536732196805,
466
+ "rewards/compute_reward/mean": 0.31500001549720763,
467
+ "rewards/compute_reward/std": 0.032658536732196805,
468
+ "step": 170,
469
+ "step_time": 15.339705387198773
470
+ },
471
+ {
472
+ "clip_ratio/high_max": 0.0,
473
+ "clip_ratio/high_mean": 0.0,
474
+ "clip_ratio/low_mean": 0.0,
475
+ "clip_ratio/low_min": 0.0,
476
+ "clip_ratio/region_mean": 0.0,
477
+ "completions/clipped_ratio": 1.0,
478
+ "completions/max_length": 100.0,
479
+ "completions/max_terminated_length": 0.0,
480
+ "completions/mean_length": 100.0,
481
+ "completions/mean_terminated_length": 0.0,
482
+ "completions/min_length": 100.0,
483
+ "completions/min_terminated_length": 0.0,
484
+ "entropy": 1.3570319384336471,
485
+ "epoch": 0.18,
486
+ "frac_reward_zero_std": 0.9,
487
+ "grad_norm": 2.3024227619171143,
488
+ "learning_rate": 5.250000000000001e-07,
489
+ "loss": 8.195638656616212e-09,
490
+ "num_tokens": 639432.0,
491
+ "reward": 0.3075000137090683,
492
+ "reward_std": 0.02121320217847824,
493
+ "rewards/compute_reward/mean": 0.3075000137090683,
494
+ "rewards/compute_reward/std": 0.02121320217847824,
495
+ "step": 180,
496
+ "step_time": 14.838772397398861
497
+ },
498
+ {
499
+ "clip_ratio/high_max": 0.0,
500
+ "clip_ratio/high_mean": 0.0,
501
+ "clip_ratio/low_mean": 0.0,
502
+ "clip_ratio/low_min": 0.0,
503
+ "clip_ratio/region_mean": 0.0,
504
+ "completions/clipped_ratio": 1.0,
505
+ "completions/max_length": 100.0,
506
+ "completions/max_terminated_length": 0.0,
507
+ "completions/mean_length": 100.0,
508
+ "completions/mean_terminated_length": 0.0,
509
+ "completions/min_length": 100.0,
510
+ "completions/min_terminated_length": 0.0,
511
+ "entropy": 1.4456530869007111,
512
+ "epoch": 0.19,
513
+ "frac_reward_zero_std": 0.75,
514
+ "grad_norm": 1.3511810302734375,
515
+ "learning_rate": 2.75e-07,
516
+ "loss": 4.0978193283081055e-08,
517
+ "num_tokens": 674972.0,
518
+ "reward": 0.31125001311302186,
519
+ "reward_std": 0.03181980364024639,
520
+ "rewards/compute_reward/mean": 0.31125001311302186,
521
+ "rewards/compute_reward/std": 0.03181980326771736,
522
+ "step": 190,
523
+ "step_time": 15.081224197598932
524
+ },
525
+ {
526
+ "clip_ratio/high_max": 0.0,
527
+ "clip_ratio/high_mean": 0.0,
528
+ "clip_ratio/low_mean": 0.0,
529
+ "clip_ratio/low_min": 0.0,
530
+ "clip_ratio/region_mean": 0.0,
531
+ "completions/clipped_ratio": 1.0,
532
+ "completions/max_length": 100.0,
533
+ "completions/max_terminated_length": 0.0,
534
+ "completions/mean_length": 100.0,
535
+ "completions/mean_terminated_length": 0.0,
536
+ "completions/min_length": 100.0,
537
+ "completions/min_terminated_length": 0.0,
538
+ "entropy": 1.5674545228481294,
539
+ "epoch": 0.2,
540
+ "frac_reward_zero_std": 0.75,
541
+ "grad_norm": 1.818772792816162,
542
+ "learning_rate": 2.5000000000000002e-08,
543
+ "loss": 3.129243850708008e-08,
544
+ "num_tokens": 709480.0,
545
+ "reward": 0.31750001311302184,
546
+ "reward_std": 0.04316474497318268,
547
+ "rewards/compute_reward/mean": 0.31750001311302184,
548
+ "rewards/compute_reward/std": 0.04316474497318268,
549
+ "step": 200,
550
+ "step_time": 15.07085579989798
551
+ }
552
+ ],
553
+ "logging_steps": 10,
554
+ "max_steps": 200,
555
+ "num_input_tokens_seen": 709480,
556
+ "num_train_epochs": 1,
557
+ "save_steps": 100,
558
+ "stateful_callbacks": {
559
+ "TrainerControl": {
560
+ "args": {
561
+ "should_epoch_stop": false,
562
+ "should_evaluate": false,
563
+ "should_log": false,
564
+ "should_save": true,
565
+ "should_training_stop": true
566
+ },
567
+ "attributes": {}
568
+ }
569
+ },
570
+ "total_flos": 0.0,
571
+ "train_batch_size": 2,
572
+ "trial_name": null,
573
+ "trial_params": null
574
+ }
training/checkpoints/phase2_final/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/phase2_final/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/phase2_final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/phase2_final/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/unified_final/README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ model_name: unified_final
4
+ tags:
5
+ - generated_from_trainer
6
+ - trl
7
+ - grpo
8
+ licence: license
9
+ ---
10
+
11
+ # Model Card for unified_final
12
+
13
+ This model is a fine-tuned version of [None](https://huggingface.co/None).
14
+ It has been trained using [TRL](https://github.com/huggingface/trl).
15
+
16
+ ## Quick start
17
+
18
+ ```python
19
+ from transformers import pipeline
20
+
21
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
22
+ generator = pipeline("text-generation", model="None", device="cuda")
23
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
24
+ print(output["generated_text"])
25
+ ```
26
+
27
+ ## Training procedure
28
+
29
+
30
+
31
+
32
+
33
+ This model was trained with GRPO, a method introduced in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
34
+
35
+ ### Framework versions
36
+
37
+ - TRL: 0.29.0
38
+ - Transformers: 5.3.0
39
+ - Pytorch: 2.12.0.dev20260307+cu128
40
+ - Datasets: 4.6.1
41
+ - Tokenizers: 0.22.2
42
+
43
+ ## Citations
44
+
45
+ Cite GRPO as:
46
+
47
+ ```bibtex
48
+ @article{shao2024deepseekmath,
49
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
50
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
51
+ year = 2024,
52
+ eprint = {arXiv:2402.03300},
53
+ }
54
+
55
+ ```
56
+
57
+ Cite TRL as:
58
+
59
+ ```bibtex
60
+ @software{vonwerra2020trl,
61
+ title = {{TRL: Transformers Reinforcement Learning}},
62
+ author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouédec, Quentin},
63
+ license = {Apache-2.0},
64
+ url = {https://github.com/huggingface/trl},
65
+ year = {2020}
66
+ }
67
+ ```
training/checkpoints/unified_final/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/unified_final/checkpoint-100/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/unified_final/checkpoint-100/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/unified_final/checkpoint-100/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/unified_final/checkpoint-100/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/unified_final/checkpoint-100/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/unified_final/checkpoint-100/trainer_state.json ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.1,
6
+ "eval_steps": 500,
7
+ "global_step": 100,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "clip_ratio/high_max": 0.0,
14
+ "clip_ratio/high_mean": 0.0,
15
+ "clip_ratio/low_mean": 0.0,
16
+ "clip_ratio/low_min": 0.0,
17
+ "clip_ratio/region_mean": 0.0,
18
+ "completions/clipped_ratio": 1.0,
19
+ "completions/max_length": 100.0,
20
+ "completions/max_terminated_length": 0.0,
21
+ "completions/mean_length": 100.0,
22
+ "completions/mean_terminated_length": 0.0,
23
+ "completions/min_length": 100.0,
24
+ "completions/min_terminated_length": 0.0,
25
+ "entropy": 1.3566992908716202,
26
+ "epoch": 0.01,
27
+ "frac_reward_zero_std": 0.0,
28
+ "grad_norm": 0.7344621419906616,
29
+ "learning_rate": 4.775e-06,
30
+ "loss": 3.0994415283203126e-07,
31
+ "num_tokens": 35800.0,
32
+ "reward": 0.01268580500036478,
33
+ "reward_std": 0.02462496655061841,
34
+ "rewards/compute_reward/mean": 0.01268580500036478,
35
+ "rewards/compute_reward/std": 0.024624967435374855,
36
+ "step": 10,
37
+ "step_time": 10.7134718033034
38
+ },
39
+ {
40
+ "clip_ratio/high_max": 0.0,
41
+ "clip_ratio/high_mean": 0.0,
42
+ "clip_ratio/low_mean": 0.0,
43
+ "clip_ratio/low_min": 0.0,
44
+ "clip_ratio/region_mean": 0.0,
45
+ "completions/clipped_ratio": 1.0,
46
+ "completions/max_length": 100.0,
47
+ "completions/max_terminated_length": 0.0,
48
+ "completions/mean_length": 100.0,
49
+ "completions/mean_terminated_length": 0.0,
50
+ "completions/min_length": 100.0,
51
+ "completions/min_terminated_length": 0.0,
52
+ "entropy": 1.3169427752494811,
53
+ "epoch": 0.02,
54
+ "frac_reward_zero_std": 0.0,
55
+ "grad_norm": 4.441726207733154,
56
+ "learning_rate": 4.525000000000001e-06,
57
+ "loss": -4.246830940246582e-07,
58
+ "num_tokens": 71748.0,
59
+ "reward": -0.04455982223153114,
60
+ "reward_std": 0.035665383422747256,
61
+ "rewards/compute_reward/mean": -0.04455982223153114,
62
+ "rewards/compute_reward/std": 0.03566538490122184,
63
+ "step": 20,
64
+ "step_time": 10.643421414200565
65
+ },
66
+ {
67
+ "clip_ratio/high_max": 0.0,
68
+ "clip_ratio/high_mean": 0.0,
69
+ "clip_ratio/low_mean": 0.0,
70
+ "clip_ratio/low_min": 0.0,
71
+ "clip_ratio/region_mean": 0.0,
72
+ "completions/clipped_ratio": 0.9875,
73
+ "completions/max_length": 100.0,
74
+ "completions/max_terminated_length": 8.9,
75
+ "completions/mean_length": 99.8625,
76
+ "completions/mean_terminated_length": 8.9,
77
+ "completions/min_length": 98.9,
78
+ "completions/min_terminated_length": 8.9,
79
+ "entropy": 1.0057833462953567,
80
+ "epoch": 0.03,
81
+ "frac_reward_zero_std": 0.0,
82
+ "grad_norm": 3.0170326232910156,
83
+ "learning_rate": 4.2750000000000006e-06,
84
+ "loss": -0.0018164031207561493,
85
+ "num_tokens": 108181.0,
86
+ "reward": 0.0374881561845541,
87
+ "reward_std": 0.020618790527805686,
88
+ "rewards/compute_reward/mean": 0.0374881561845541,
89
+ "rewards/compute_reward/std": 0.0206187907140702,
90
+ "step": 30,
91
+ "step_time": 10.756140169796709
92
+ },
93
+ {
94
+ "clip_ratio/high_max": 0.0,
95
+ "clip_ratio/high_mean": 0.0,
96
+ "clip_ratio/low_mean": 0.0,
97
+ "clip_ratio/low_min": 0.0,
98
+ "clip_ratio/region_mean": 0.0,
99
+ "completions/clipped_ratio": 0.9875,
100
+ "completions/max_length": 100.0,
101
+ "completions/max_terminated_length": 6.6,
102
+ "completions/mean_length": 99.575,
103
+ "completions/mean_terminated_length": 6.6,
104
+ "completions/min_length": 96.6,
105
+ "completions/min_terminated_length": 6.6,
106
+ "entropy": 1.7816664546728134,
107
+ "epoch": 0.04,
108
+ "frac_reward_zero_std": 0.0,
109
+ "grad_norm": 5.86561393737793,
110
+ "learning_rate": 4.0250000000000004e-06,
111
+ "loss": -0.006361240148544311,
112
+ "num_tokens": 143375.0,
113
+ "reward": -0.014824284799396991,
114
+ "reward_std": 0.06699581742286682,
115
+ "rewards/compute_reward/mean": -0.014824284799396991,
116
+ "rewards/compute_reward/std": 0.06699582003057003,
117
+ "step": 40,
118
+ "step_time": 10.785410385398427
119
+ },
120
+ {
121
+ "clip_ratio/high_max": 0.0,
122
+ "clip_ratio/high_mean": 0.0,
123
+ "clip_ratio/low_mean": 0.0,
124
+ "clip_ratio/low_min": 0.0,
125
+ "clip_ratio/region_mean": 0.0,
126
+ "completions/clipped_ratio": 0.9875,
127
+ "completions/max_length": 100.0,
128
+ "completions/max_terminated_length": 3.0,
129
+ "completions/mean_length": 99.125,
130
+ "completions/mean_terminated_length": 3.0,
131
+ "completions/min_length": 93.0,
132
+ "completions/min_terminated_length": 3.0,
133
+ "entropy": 2.1307705104351045,
134
+ "epoch": 0.05,
135
+ "frac_reward_zero_std": 0.0,
136
+ "grad_norm": 6.191352367401123,
137
+ "learning_rate": 3.7750000000000003e-06,
138
+ "loss": -0.011027154326438905,
139
+ "num_tokens": 178941.0,
140
+ "reward": -0.016337488451972602,
141
+ "reward_std": 0.051818730868399145,
142
+ "rewards/compute_reward/mean": -0.016337488451972602,
143
+ "rewards/compute_reward/std": 0.05181873142719269,
144
+ "step": 50,
145
+ "step_time": 10.741381045605522
146
+ },
147
+ {
148
+ "clip_ratio/high_max": 0.0,
149
+ "clip_ratio/high_mean": 0.0,
150
+ "clip_ratio/low_mean": 0.0,
151
+ "clip_ratio/low_min": 0.0,
152
+ "clip_ratio/region_mean": 0.0,
153
+ "completions/clipped_ratio": 0.9875,
154
+ "completions/max_length": 100.0,
155
+ "completions/max_terminated_length": 8.8,
156
+ "completions/mean_length": 99.85,
157
+ "completions/mean_terminated_length": 8.8,
158
+ "completions/min_length": 98.8,
159
+ "completions/min_terminated_length": 8.8,
160
+ "entropy": 2.1041357040405275,
161
+ "epoch": 0.06,
162
+ "frac_reward_zero_std": 0.0,
163
+ "grad_norm": 8.536041259765625,
164
+ "learning_rate": 3.525e-06,
165
+ "loss": 0.0019509844481945039,
166
+ "num_tokens": 216257.0,
167
+ "reward": 0.035917540453374384,
168
+ "reward_std": 0.04930563308298588,
169
+ "rewards/compute_reward/mean": 0.035917540453374384,
170
+ "rewards/compute_reward/std": 0.049305635318160054,
171
+ "step": 60,
172
+ "step_time": 11.27133785020269
173
+ },
174
+ {
175
+ "clip_ratio/high_max": 0.0,
176
+ "clip_ratio/high_mean": 0.0,
177
+ "clip_ratio/low_mean": 0.0,
178
+ "clip_ratio/low_min": 0.0,
179
+ "clip_ratio/region_mean": 0.0,
180
+ "completions/clipped_ratio": 0.8,
181
+ "completions/max_length": 100.0,
182
+ "completions/max_terminated_length": 48.2,
183
+ "completions/mean_length": 92.9625,
184
+ "completions/mean_terminated_length": 38.51333351135254,
185
+ "completions/min_length": 70.1,
186
+ "completions/min_terminated_length": 30.1,
187
+ "entropy": 1.6469052851200103,
188
+ "epoch": 0.07,
189
+ "frac_reward_zero_std": 0.0,
190
+ "grad_norm": 6.919373512268066,
191
+ "learning_rate": 3.2750000000000004e-06,
192
+ "loss": -0.02075239419937134,
193
+ "num_tokens": 251110.0,
194
+ "reward": 0.007261525164358318,
195
+ "reward_std": 0.0802696269005537,
196
+ "rewards/compute_reward/mean": 0.007261525164358318,
197
+ "rewards/compute_reward/std": 0.08026962876319885,
198
+ "step": 70,
199
+ "step_time": 10.774873650902009
200
+ },
201
+ {
202
+ "clip_ratio/high_max": 0.0,
203
+ "clip_ratio/high_mean": 0.0,
204
+ "clip_ratio/low_mean": 0.0,
205
+ "clip_ratio/low_min": 0.0,
206
+ "clip_ratio/region_mean": 0.0,
207
+ "completions/clipped_ratio": 0.9875,
208
+ "completions/max_length": 100.0,
209
+ "completions/max_terminated_length": 3.1,
210
+ "completions/mean_length": 99.1375,
211
+ "completions/mean_terminated_length": 3.1,
212
+ "completions/min_length": 93.1,
213
+ "completions/min_terminated_length": 3.1,
214
+ "entropy": 2.2336367428302766,
215
+ "epoch": 0.08,
216
+ "frac_reward_zero_std": 0.0,
217
+ "grad_norm": 4.918172836303711,
218
+ "learning_rate": 3.0250000000000003e-06,
219
+ "loss": 0.008250368386507034,
220
+ "num_tokens": 285729.0,
221
+ "reward": 0.027657157555222512,
222
+ "reward_std": 0.04840414375066757,
223
+ "rewards/compute_reward/mean": 0.027657157555222512,
224
+ "rewards/compute_reward/std": 0.048404145427048205,
225
+ "step": 80,
226
+ "step_time": 10.43483721170196
227
+ },
228
+ {
229
+ "clip_ratio/high_max": 0.0,
230
+ "clip_ratio/high_mean": 0.0,
231
+ "clip_ratio/low_mean": 0.0,
232
+ "clip_ratio/low_min": 0.0,
233
+ "clip_ratio/region_mean": 0.0,
234
+ "completions/clipped_ratio": 1.0,
235
+ "completions/max_length": 100.0,
236
+ "completions/max_terminated_length": 0.0,
237
+ "completions/mean_length": 100.0,
238
+ "completions/mean_terminated_length": 0.0,
239
+ "completions/min_length": 100.0,
240
+ "completions/min_terminated_length": 0.0,
241
+ "entropy": 1.8057245463132858,
242
+ "epoch": 0.09,
243
+ "frac_reward_zero_std": 0.0,
244
+ "grad_norm": 4.417481422424316,
245
+ "learning_rate": 2.7750000000000005e-06,
246
+ "loss": 2.216547727584839e-08,
247
+ "num_tokens": 320249.0,
248
+ "reward": 0.07908838111907243,
249
+ "reward_std": 0.07920666746795177,
250
+ "rewards/compute_reward/mean": 0.07908838111907243,
251
+ "rewards/compute_reward/std": 0.07920666970312595,
252
+ "step": 90,
253
+ "step_time": 10.337220244196942
254
+ },
255
+ {
256
+ "clip_ratio/high_max": 0.0,
257
+ "clip_ratio/high_mean": 0.0,
258
+ "clip_ratio/low_mean": 0.0,
259
+ "clip_ratio/low_min": 0.0,
260
+ "clip_ratio/region_mean": 0.0,
261
+ "completions/clipped_ratio": 1.0,
262
+ "completions/max_length": 100.0,
263
+ "completions/max_terminated_length": 0.0,
264
+ "completions/mean_length": 100.0,
265
+ "completions/mean_terminated_length": 0.0,
266
+ "completions/min_length": 100.0,
267
+ "completions/min_terminated_length": 0.0,
268
+ "entropy": 1.4064194440841675,
269
+ "epoch": 0.1,
270
+ "frac_reward_zero_std": 0.0,
271
+ "grad_norm": 3.352966785430908,
272
+ "learning_rate": 2.5250000000000004e-06,
273
+ "loss": 8.493661880493164e-08,
274
+ "num_tokens": 355369.0,
275
+ "reward": 0.14763977155089378,
276
+ "reward_std": 0.07424246501177549,
277
+ "rewards/compute_reward/mean": 0.14763977155089378,
278
+ "rewards/compute_reward/std": 0.0742424676194787,
279
+ "step": 100,
280
+ "step_time": 10.74917738300719
281
+ }
282
+ ],
283
+ "logging_steps": 10,
284
+ "max_steps": 200,
285
+ "num_input_tokens_seen": 355369,
286
+ "num_train_epochs": 1,
287
+ "save_steps": 100,
288
+ "stateful_callbacks": {
289
+ "TrainerControl": {
290
+ "args": {
291
+ "should_epoch_stop": false,
292
+ "should_evaluate": false,
293
+ "should_log": false,
294
+ "should_save": true,
295
+ "should_training_stop": false
296
+ },
297
+ "attributes": {}
298
+ }
299
+ },
300
+ "total_flos": 0.0,
301
+ "train_batch_size": 2,
302
+ "trial_name": null,
303
+ "trial_params": null
304
+ }
training/checkpoints/unified_final/checkpoint-200/chat_template.jinja ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% for message in messages %}
2
+ {% if message['role'] == 'user' %}
3
+ {{ '<|user|>
4
+ ' + message['content'] + eos_token }}
5
+ {% elif message['role'] == 'system' %}
6
+ {{ '<|system|>
7
+ ' + message['content'] + eos_token }}
8
+ {% elif message['role'] == 'assistant' %}
9
+ {{ '<|assistant|>
10
+ ' + message['content'] + eos_token }}
11
+ {% endif %}
12
+ {% if loop.last and add_generation_prompt %}
13
+ {{ '<|assistant|>' }}
14
+ {% endif %}
15
+ {% endfor %}
training/checkpoints/unified_final/checkpoint-200/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/unified_final/checkpoint-200/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/unified_final/checkpoint-200/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/unified_final/checkpoint-200/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/unified_final/checkpoint-200/trainer_state.json ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.2,
6
+ "eval_steps": 500,
7
+ "global_step": 200,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "clip_ratio/high_max": 0.0,
14
+ "clip_ratio/high_mean": 0.0,
15
+ "clip_ratio/low_mean": 0.0,
16
+ "clip_ratio/low_min": 0.0,
17
+ "clip_ratio/region_mean": 0.0,
18
+ "completions/clipped_ratio": 1.0,
19
+ "completions/max_length": 100.0,
20
+ "completions/max_terminated_length": 0.0,
21
+ "completions/mean_length": 100.0,
22
+ "completions/mean_terminated_length": 0.0,
23
+ "completions/min_length": 100.0,
24
+ "completions/min_terminated_length": 0.0,
25
+ "entropy": 1.3566992908716202,
26
+ "epoch": 0.01,
27
+ "frac_reward_zero_std": 0.0,
28
+ "grad_norm": 0.7344621419906616,
29
+ "learning_rate": 4.775e-06,
30
+ "loss": 3.0994415283203126e-07,
31
+ "num_tokens": 35800.0,
32
+ "reward": 0.01268580500036478,
33
+ "reward_std": 0.02462496655061841,
34
+ "rewards/compute_reward/mean": 0.01268580500036478,
35
+ "rewards/compute_reward/std": 0.024624967435374855,
36
+ "step": 10,
37
+ "step_time": 10.7134718033034
38
+ },
39
+ {
40
+ "clip_ratio/high_max": 0.0,
41
+ "clip_ratio/high_mean": 0.0,
42
+ "clip_ratio/low_mean": 0.0,
43
+ "clip_ratio/low_min": 0.0,
44
+ "clip_ratio/region_mean": 0.0,
45
+ "completions/clipped_ratio": 1.0,
46
+ "completions/max_length": 100.0,
47
+ "completions/max_terminated_length": 0.0,
48
+ "completions/mean_length": 100.0,
49
+ "completions/mean_terminated_length": 0.0,
50
+ "completions/min_length": 100.0,
51
+ "completions/min_terminated_length": 0.0,
52
+ "entropy": 1.3169427752494811,
53
+ "epoch": 0.02,
54
+ "frac_reward_zero_std": 0.0,
55
+ "grad_norm": 4.441726207733154,
56
+ "learning_rate": 4.525000000000001e-06,
57
+ "loss": -4.246830940246582e-07,
58
+ "num_tokens": 71748.0,
59
+ "reward": -0.04455982223153114,
60
+ "reward_std": 0.035665383422747256,
61
+ "rewards/compute_reward/mean": -0.04455982223153114,
62
+ "rewards/compute_reward/std": 0.03566538490122184,
63
+ "step": 20,
64
+ "step_time": 10.643421414200565
65
+ },
66
+ {
67
+ "clip_ratio/high_max": 0.0,
68
+ "clip_ratio/high_mean": 0.0,
69
+ "clip_ratio/low_mean": 0.0,
70
+ "clip_ratio/low_min": 0.0,
71
+ "clip_ratio/region_mean": 0.0,
72
+ "completions/clipped_ratio": 0.9875,
73
+ "completions/max_length": 100.0,
74
+ "completions/max_terminated_length": 8.9,
75
+ "completions/mean_length": 99.8625,
76
+ "completions/mean_terminated_length": 8.9,
77
+ "completions/min_length": 98.9,
78
+ "completions/min_terminated_length": 8.9,
79
+ "entropy": 1.0057833462953567,
80
+ "epoch": 0.03,
81
+ "frac_reward_zero_std": 0.0,
82
+ "grad_norm": 3.0170326232910156,
83
+ "learning_rate": 4.2750000000000006e-06,
84
+ "loss": -0.0018164031207561493,
85
+ "num_tokens": 108181.0,
86
+ "reward": 0.0374881561845541,
87
+ "reward_std": 0.020618790527805686,
88
+ "rewards/compute_reward/mean": 0.0374881561845541,
89
+ "rewards/compute_reward/std": 0.0206187907140702,
90
+ "step": 30,
91
+ "step_time": 10.756140169796709
92
+ },
93
+ {
94
+ "clip_ratio/high_max": 0.0,
95
+ "clip_ratio/high_mean": 0.0,
96
+ "clip_ratio/low_mean": 0.0,
97
+ "clip_ratio/low_min": 0.0,
98
+ "clip_ratio/region_mean": 0.0,
99
+ "completions/clipped_ratio": 0.9875,
100
+ "completions/max_length": 100.0,
101
+ "completions/max_terminated_length": 6.6,
102
+ "completions/mean_length": 99.575,
103
+ "completions/mean_terminated_length": 6.6,
104
+ "completions/min_length": 96.6,
105
+ "completions/min_terminated_length": 6.6,
106
+ "entropy": 1.7816664546728134,
107
+ "epoch": 0.04,
108
+ "frac_reward_zero_std": 0.0,
109
+ "grad_norm": 5.86561393737793,
110
+ "learning_rate": 4.0250000000000004e-06,
111
+ "loss": -0.006361240148544311,
112
+ "num_tokens": 143375.0,
113
+ "reward": -0.014824284799396991,
114
+ "reward_std": 0.06699581742286682,
115
+ "rewards/compute_reward/mean": -0.014824284799396991,
116
+ "rewards/compute_reward/std": 0.06699582003057003,
117
+ "step": 40,
118
+ "step_time": 10.785410385398427
119
+ },
120
+ {
121
+ "clip_ratio/high_max": 0.0,
122
+ "clip_ratio/high_mean": 0.0,
123
+ "clip_ratio/low_mean": 0.0,
124
+ "clip_ratio/low_min": 0.0,
125
+ "clip_ratio/region_mean": 0.0,
126
+ "completions/clipped_ratio": 0.9875,
127
+ "completions/max_length": 100.0,
128
+ "completions/max_terminated_length": 3.0,
129
+ "completions/mean_length": 99.125,
130
+ "completions/mean_terminated_length": 3.0,
131
+ "completions/min_length": 93.0,
132
+ "completions/min_terminated_length": 3.0,
133
+ "entropy": 2.1307705104351045,
134
+ "epoch": 0.05,
135
+ "frac_reward_zero_std": 0.0,
136
+ "grad_norm": 6.191352367401123,
137
+ "learning_rate": 3.7750000000000003e-06,
138
+ "loss": -0.011027154326438905,
139
+ "num_tokens": 178941.0,
140
+ "reward": -0.016337488451972602,
141
+ "reward_std": 0.051818730868399145,
142
+ "rewards/compute_reward/mean": -0.016337488451972602,
143
+ "rewards/compute_reward/std": 0.05181873142719269,
144
+ "step": 50,
145
+ "step_time": 10.741381045605522
146
+ },
147
+ {
148
+ "clip_ratio/high_max": 0.0,
149
+ "clip_ratio/high_mean": 0.0,
150
+ "clip_ratio/low_mean": 0.0,
151
+ "clip_ratio/low_min": 0.0,
152
+ "clip_ratio/region_mean": 0.0,
153
+ "completions/clipped_ratio": 0.9875,
154
+ "completions/max_length": 100.0,
155
+ "completions/max_terminated_length": 8.8,
156
+ "completions/mean_length": 99.85,
157
+ "completions/mean_terminated_length": 8.8,
158
+ "completions/min_length": 98.8,
159
+ "completions/min_terminated_length": 8.8,
160
+ "entropy": 2.1041357040405275,
161
+ "epoch": 0.06,
162
+ "frac_reward_zero_std": 0.0,
163
+ "grad_norm": 8.536041259765625,
164
+ "learning_rate": 3.525e-06,
165
+ "loss": 0.0019509844481945039,
166
+ "num_tokens": 216257.0,
167
+ "reward": 0.035917540453374384,
168
+ "reward_std": 0.04930563308298588,
169
+ "rewards/compute_reward/mean": 0.035917540453374384,
170
+ "rewards/compute_reward/std": 0.049305635318160054,
171
+ "step": 60,
172
+ "step_time": 11.27133785020269
173
+ },
174
+ {
175
+ "clip_ratio/high_max": 0.0,
176
+ "clip_ratio/high_mean": 0.0,
177
+ "clip_ratio/low_mean": 0.0,
178
+ "clip_ratio/low_min": 0.0,
179
+ "clip_ratio/region_mean": 0.0,
180
+ "completions/clipped_ratio": 0.8,
181
+ "completions/max_length": 100.0,
182
+ "completions/max_terminated_length": 48.2,
183
+ "completions/mean_length": 92.9625,
184
+ "completions/mean_terminated_length": 38.51333351135254,
185
+ "completions/min_length": 70.1,
186
+ "completions/min_terminated_length": 30.1,
187
+ "entropy": 1.6469052851200103,
188
+ "epoch": 0.07,
189
+ "frac_reward_zero_std": 0.0,
190
+ "grad_norm": 6.919373512268066,
191
+ "learning_rate": 3.2750000000000004e-06,
192
+ "loss": -0.02075239419937134,
193
+ "num_tokens": 251110.0,
194
+ "reward": 0.007261525164358318,
195
+ "reward_std": 0.0802696269005537,
196
+ "rewards/compute_reward/mean": 0.007261525164358318,
197
+ "rewards/compute_reward/std": 0.08026962876319885,
198
+ "step": 70,
199
+ "step_time": 10.774873650902009
200
+ },
201
+ {
202
+ "clip_ratio/high_max": 0.0,
203
+ "clip_ratio/high_mean": 0.0,
204
+ "clip_ratio/low_mean": 0.0,
205
+ "clip_ratio/low_min": 0.0,
206
+ "clip_ratio/region_mean": 0.0,
207
+ "completions/clipped_ratio": 0.9875,
208
+ "completions/max_length": 100.0,
209
+ "completions/max_terminated_length": 3.1,
210
+ "completions/mean_length": 99.1375,
211
+ "completions/mean_terminated_length": 3.1,
212
+ "completions/min_length": 93.1,
213
+ "completions/min_terminated_length": 3.1,
214
+ "entropy": 2.2336367428302766,
215
+ "epoch": 0.08,
216
+ "frac_reward_zero_std": 0.0,
217
+ "grad_norm": 4.918172836303711,
218
+ "learning_rate": 3.0250000000000003e-06,
219
+ "loss": 0.008250368386507034,
220
+ "num_tokens": 285729.0,
221
+ "reward": 0.027657157555222512,
222
+ "reward_std": 0.04840414375066757,
223
+ "rewards/compute_reward/mean": 0.027657157555222512,
224
+ "rewards/compute_reward/std": 0.048404145427048205,
225
+ "step": 80,
226
+ "step_time": 10.43483721170196
227
+ },
228
+ {
229
+ "clip_ratio/high_max": 0.0,
230
+ "clip_ratio/high_mean": 0.0,
231
+ "clip_ratio/low_mean": 0.0,
232
+ "clip_ratio/low_min": 0.0,
233
+ "clip_ratio/region_mean": 0.0,
234
+ "completions/clipped_ratio": 1.0,
235
+ "completions/max_length": 100.0,
236
+ "completions/max_terminated_length": 0.0,
237
+ "completions/mean_length": 100.0,
238
+ "completions/mean_terminated_length": 0.0,
239
+ "completions/min_length": 100.0,
240
+ "completions/min_terminated_length": 0.0,
241
+ "entropy": 1.8057245463132858,
242
+ "epoch": 0.09,
243
+ "frac_reward_zero_std": 0.0,
244
+ "grad_norm": 4.417481422424316,
245
+ "learning_rate": 2.7750000000000005e-06,
246
+ "loss": 2.216547727584839e-08,
247
+ "num_tokens": 320249.0,
248
+ "reward": 0.07908838111907243,
249
+ "reward_std": 0.07920666746795177,
250
+ "rewards/compute_reward/mean": 0.07908838111907243,
251
+ "rewards/compute_reward/std": 0.07920666970312595,
252
+ "step": 90,
253
+ "step_time": 10.337220244196942
254
+ },
255
+ {
256
+ "clip_ratio/high_max": 0.0,
257
+ "clip_ratio/high_mean": 0.0,
258
+ "clip_ratio/low_mean": 0.0,
259
+ "clip_ratio/low_min": 0.0,
260
+ "clip_ratio/region_mean": 0.0,
261
+ "completions/clipped_ratio": 1.0,
262
+ "completions/max_length": 100.0,
263
+ "completions/max_terminated_length": 0.0,
264
+ "completions/mean_length": 100.0,
265
+ "completions/mean_terminated_length": 0.0,
266
+ "completions/min_length": 100.0,
267
+ "completions/min_terminated_length": 0.0,
268
+ "entropy": 1.4064194440841675,
269
+ "epoch": 0.1,
270
+ "frac_reward_zero_std": 0.0,
271
+ "grad_norm": 3.352966785430908,
272
+ "learning_rate": 2.5250000000000004e-06,
273
+ "loss": 8.493661880493164e-08,
274
+ "num_tokens": 355369.0,
275
+ "reward": 0.14763977155089378,
276
+ "reward_std": 0.07424246501177549,
277
+ "rewards/compute_reward/mean": 0.14763977155089378,
278
+ "rewards/compute_reward/std": 0.0742424676194787,
279
+ "step": 100,
280
+ "step_time": 10.74917738300719
281
+ },
282
+ {
283
+ "clip_ratio/high_max": 0.0,
284
+ "clip_ratio/high_mean": 0.0,
285
+ "clip_ratio/low_mean": 0.0,
286
+ "clip_ratio/low_min": 0.0,
287
+ "clip_ratio/region_mean": 0.0,
288
+ "completions/clipped_ratio": 1.0,
289
+ "completions/max_length": 100.0,
290
+ "completions/max_terminated_length": 0.0,
291
+ "completions/mean_length": 100.0,
292
+ "completions/mean_terminated_length": 0.0,
293
+ "completions/min_length": 100.0,
294
+ "completions/min_terminated_length": 0.0,
295
+ "entropy": 1.2582464694976807,
296
+ "epoch": 0.11,
297
+ "frac_reward_zero_std": 0.0,
298
+ "grad_norm": 3.9595463275909424,
299
+ "learning_rate": 2.2750000000000002e-06,
300
+ "loss": -3.874301910400391e-08,
301
+ "num_tokens": 392289.0,
302
+ "reward": 0.18278183937072753,
303
+ "reward_std": 0.052620683796703815,
304
+ "rewards/compute_reward/mean": 0.18278183937072753,
305
+ "rewards/compute_reward/std": 0.05262068491429091,
306
+ "step": 110,
307
+ "step_time": 11.17140419179923
308
+ },
309
+ {
310
+ "clip_ratio/high_max": 0.0,
311
+ "clip_ratio/high_mean": 0.0,
312
+ "clip_ratio/low_mean": 0.0,
313
+ "clip_ratio/low_min": 0.0,
314
+ "clip_ratio/region_mean": 0.0,
315
+ "completions/clipped_ratio": 1.0,
316
+ "completions/max_length": 100.0,
317
+ "completions/max_terminated_length": 0.0,
318
+ "completions/mean_length": 100.0,
319
+ "completions/mean_terminated_length": 0.0,
320
+ "completions/min_length": 100.0,
321
+ "completions/min_terminated_length": 0.0,
322
+ "entropy": 0.8805452413856983,
323
+ "epoch": 0.12,
324
+ "frac_reward_zero_std": 0.0,
325
+ "grad_norm": 2.707214593887329,
326
+ "learning_rate": 2.025e-06,
327
+ "loss": 1.5050172805786132e-07,
328
+ "num_tokens": 430501.0,
329
+ "reward": 0.22903144657611846,
330
+ "reward_std": 0.04029850559309125,
331
+ "rewards/compute_reward/mean": 0.22903144657611846,
332
+ "rewards/compute_reward/std": 0.04029850568622351,
333
+ "step": 120,
334
+ "step_time": 11.244449263699062
335
+ },
336
+ {
337
+ "clip_ratio/high_max": 0.0,
338
+ "clip_ratio/high_mean": 0.0,
339
+ "clip_ratio/low_mean": 0.0,
340
+ "clip_ratio/low_min": 0.0,
341
+ "clip_ratio/region_mean": 0.0,
342
+ "completions/clipped_ratio": 1.0,
343
+ "completions/max_length": 100.0,
344
+ "completions/max_terminated_length": 0.0,
345
+ "completions/mean_length": 100.0,
346
+ "completions/mean_terminated_length": 0.0,
347
+ "completions/min_length": 100.0,
348
+ "completions/min_terminated_length": 0.0,
349
+ "entropy": 0.8755271568894386,
350
+ "epoch": 0.13,
351
+ "frac_reward_zero_std": 0.0,
352
+ "grad_norm": 3.942605495452881,
353
+ "learning_rate": 1.7750000000000002e-06,
354
+ "loss": 1.2218952178955077e-07,
355
+ "num_tokens": 467245.0,
356
+ "reward": 0.18334048390388488,
357
+ "reward_std": 0.07254596166312695,
358
+ "rewards/compute_reward/mean": 0.18334048390388488,
359
+ "rewards/compute_reward/std": 0.072545962408185,
360
+ "step": 130,
361
+ "step_time": 11.071729802998016
362
+ },
363
+ {
364
+ "clip_ratio/high_max": 0.0,
365
+ "clip_ratio/high_mean": 0.0,
366
+ "clip_ratio/low_mean": 0.0,
367
+ "clip_ratio/low_min": 0.0,
368
+ "clip_ratio/region_mean": 0.0,
369
+ "completions/clipped_ratio": 1.0,
370
+ "completions/max_length": 100.0,
371
+ "completions/max_terminated_length": 0.0,
372
+ "completions/mean_length": 100.0,
373
+ "completions/mean_terminated_length": 0.0,
374
+ "completions/min_length": 100.0,
375
+ "completions/min_terminated_length": 0.0,
376
+ "entropy": 0.9737002968788147,
377
+ "epoch": 0.14,
378
+ "frac_reward_zero_std": 0.0,
379
+ "grad_norm": 4.040837287902832,
380
+ "learning_rate": 1.525e-06,
381
+ "loss": -1.4007091522216797e-07,
382
+ "num_tokens": 503017.0,
383
+ "reward": 0.20783505886793135,
384
+ "reward_std": 0.06580547224730253,
385
+ "rewards/compute_reward/mean": 0.20783505886793135,
386
+ "rewards/compute_reward/std": 0.06580547466874123,
387
+ "step": 140,
388
+ "step_time": 10.841636341501726
389
+ },
390
+ {
391
+ "clip_ratio/high_max": 0.0,
392
+ "clip_ratio/high_mean": 0.0,
393
+ "clip_ratio/low_mean": 0.0,
394
+ "clip_ratio/low_min": 0.0,
395
+ "clip_ratio/region_mean": 0.0,
396
+ "completions/clipped_ratio": 1.0,
397
+ "completions/max_length": 100.0,
398
+ "completions/max_terminated_length": 0.0,
399
+ "completions/mean_length": 100.0,
400
+ "completions/mean_terminated_length": 0.0,
401
+ "completions/min_length": 100.0,
402
+ "completions/min_terminated_length": 0.0,
403
+ "entropy": 0.9901166066527367,
404
+ "epoch": 0.15,
405
+ "frac_reward_zero_std": 0.0,
406
+ "grad_norm": 3.720881462097168,
407
+ "learning_rate": 1.275e-06,
408
+ "loss": 2.0861625671386717e-08,
409
+ "num_tokens": 539801.0,
410
+ "reward": 0.2224348157644272,
411
+ "reward_std": 0.05879365894943476,
412
+ "rewards/compute_reward/mean": 0.2224348157644272,
413
+ "rewards/compute_reward/std": 0.05879366043955088,
414
+ "step": 150,
415
+ "step_time": 10.85469058619783
416
+ },
417
+ {
418
+ "clip_ratio/high_max": 0.0,
419
+ "clip_ratio/high_mean": 0.0,
420
+ "clip_ratio/low_mean": 0.0,
421
+ "clip_ratio/low_min": 0.0,
422
+ "clip_ratio/region_mean": 0.0,
423
+ "completions/clipped_ratio": 1.0,
424
+ "completions/max_length": 100.0,
425
+ "completions/max_terminated_length": 0.0,
426
+ "completions/mean_length": 100.0,
427
+ "completions/mean_terminated_length": 0.0,
428
+ "completions/min_length": 100.0,
429
+ "completions/min_terminated_length": 0.0,
430
+ "entropy": 1.1208710052073,
431
+ "epoch": 0.16,
432
+ "frac_reward_zero_std": 0.0,
433
+ "grad_norm": 3.452557325363159,
434
+ "learning_rate": 1.025e-06,
435
+ "loss": 1.4603137969970704e-07,
436
+ "num_tokens": 575385.0,
437
+ "reward": 0.1992661789059639,
438
+ "reward_std": 0.06030977526679635,
439
+ "rewards/compute_reward/mean": 0.1992661789059639,
440
+ "rewards/compute_reward/std": 0.060309774987399575,
441
+ "step": 160,
442
+ "step_time": 10.620040459206212
443
+ },
444
+ {
445
+ "clip_ratio/high_max": 0.0,
446
+ "clip_ratio/high_mean": 0.0,
447
+ "clip_ratio/low_mean": 0.0,
448
+ "clip_ratio/low_min": 0.0,
449
+ "clip_ratio/region_mean": 0.0,
450
+ "completions/clipped_ratio": 0.9875,
451
+ "completions/max_length": 100.0,
452
+ "completions/max_terminated_length": 8.5,
453
+ "completions/mean_length": 99.8125,
454
+ "completions/mean_terminated_length": 8.5,
455
+ "completions/min_length": 98.5,
456
+ "completions/min_terminated_length": 8.5,
457
+ "entropy": 0.943237779289484,
458
+ "epoch": 0.17,
459
+ "frac_reward_zero_std": 0.0,
460
+ "grad_norm": 3.998199701309204,
461
+ "learning_rate": 7.750000000000001e-07,
462
+ "loss": 0.0005225777626037597,
463
+ "num_tokens": 611998.0,
464
+ "reward": 0.21552147567272187,
465
+ "reward_std": 0.032230423856526615,
466
+ "rewards/compute_reward/mean": 0.21552147567272187,
467
+ "rewards/compute_reward/std": 0.0322304243221879,
468
+ "step": 170,
469
+ "step_time": 10.901679297701047
470
+ },
471
+ {
472
+ "clip_ratio/high_max": 0.0,
473
+ "clip_ratio/high_mean": 0.0,
474
+ "clip_ratio/low_mean": 0.0,
475
+ "clip_ratio/low_min": 0.0,
476
+ "clip_ratio/region_mean": 0.0,
477
+ "completions/clipped_ratio": 1.0,
478
+ "completions/max_length": 100.0,
479
+ "completions/max_terminated_length": 0.0,
480
+ "completions/mean_length": 100.0,
481
+ "completions/mean_terminated_length": 0.0,
482
+ "completions/min_length": 100.0,
483
+ "completions/min_terminated_length": 0.0,
484
+ "entropy": 0.9798725090920926,
485
+ "epoch": 0.18,
486
+ "frac_reward_zero_std": 0.0,
487
+ "grad_norm": 3.732668161392212,
488
+ "learning_rate": 5.250000000000001e-07,
489
+ "loss": -8.270144462585449e-08,
490
+ "num_tokens": 647338.0,
491
+ "reward": 0.21226384192705156,
492
+ "reward_std": 0.06548679377883673,
493
+ "rewards/compute_reward/mean": 0.21226384192705156,
494
+ "rewards/compute_reward/std": 0.0654867960140109,
495
+ "step": 180,
496
+ "step_time": 10.853807216498534
497
+ },
498
+ {
499
+ "clip_ratio/high_max": 0.0,
500
+ "clip_ratio/high_mean": 0.0,
501
+ "clip_ratio/low_mean": 0.0,
502
+ "clip_ratio/low_min": 0.0,
503
+ "clip_ratio/region_mean": 0.0,
504
+ "completions/clipped_ratio": 1.0,
505
+ "completions/max_length": 100.0,
506
+ "completions/max_terminated_length": 0.0,
507
+ "completions/mean_length": 100.0,
508
+ "completions/mean_terminated_length": 0.0,
509
+ "completions/min_length": 100.0,
510
+ "completions/min_terminated_length": 0.0,
511
+ "entropy": 0.9461549550294877,
512
+ "epoch": 0.19,
513
+ "frac_reward_zero_std": 0.0,
514
+ "grad_norm": 3.7145590782165527,
515
+ "learning_rate": 2.75e-07,
516
+ "loss": -2.1532177925109862e-07,
517
+ "num_tokens": 682026.0,
518
+ "reward": 0.21948475018143654,
519
+ "reward_std": 0.05461370516568422,
520
+ "rewards/compute_reward/mean": 0.21948475018143654,
521
+ "rewards/compute_reward/std": 0.05461370553821325,
522
+ "step": 190,
523
+ "step_time": 10.456350517399551
524
+ },
525
+ {
526
+ "clip_ratio/high_max": 0.0,
527
+ "clip_ratio/high_mean": 0.0,
528
+ "clip_ratio/low_mean": 0.0,
529
+ "clip_ratio/low_min": 0.0,
530
+ "clip_ratio/region_mean": 0.0,
531
+ "completions/clipped_ratio": 1.0,
532
+ "completions/max_length": 100.0,
533
+ "completions/max_terminated_length": 0.0,
534
+ "completions/mean_length": 100.0,
535
+ "completions/mean_terminated_length": 0.0,
536
+ "completions/min_length": 100.0,
537
+ "completions/min_terminated_length": 0.0,
538
+ "entropy": 0.8442220821976661,
539
+ "epoch": 0.2,
540
+ "frac_reward_zero_std": 0.0,
541
+ "grad_norm": 3.7965171337127686,
542
+ "learning_rate": 2.5000000000000002e-08,
543
+ "loss": 1.0430812835693359e-08,
544
+ "num_tokens": 716746.0,
545
+ "reward": 0.2305009976029396,
546
+ "reward_std": 0.03879760131239891,
547
+ "rewards/compute_reward/mean": 0.2305009976029396,
548
+ "rewards/compute_reward/std": 0.03879760047420859,
549
+ "step": 200,
550
+ "step_time": 10.340635509999993
551
+ }
552
+ ],
553
+ "logging_steps": 10,
554
+ "max_steps": 200,
555
+ "num_input_tokens_seen": 716746,
556
+ "num_train_epochs": 1,
557
+ "save_steps": 100,
558
+ "stateful_callbacks": {
559
+ "TrainerControl": {
560
+ "args": {
561
+ "should_epoch_stop": false,
562
+ "should_evaluate": false,
563
+ "should_log": false,
564
+ "should_save": true,
565
+ "should_training_stop": true
566
+ },
567
+ "attributes": {}
568
+ }
569
+ },
570
+ "total_flos": 0.0,
571
+ "train_batch_size": 2,
572
+ "trial_name": null,
573
+ "trial_params": null
574
+ }
training/checkpoints/unified_final/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "head_dim": 64,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 2048,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5632,
15
+ "max_position_embeddings": 2048,
16
+ "mlp_bias": false,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 22,
20
+ "num_key_value_heads": 4,
21
+ "pad_token_id": 2,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-05,
24
+ "rope_parameters": {
25
+ "rope_theta": 10000.0,
26
+ "rope_type": "default"
27
+ },
28
+ "tie_word_embeddings": false,
29
+ "transformers_version": "5.3.0",
30
+ "use_cache": false,
31
+ "vocab_size": 32000
32
+ }
training/checkpoints/unified_final/generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": [
4
+ 2
5
+ ],
6
+ "max_length": 2048,
7
+ "pad_token_id": 2,
8
+ "transformers_version": "5.3.0"
9
+ }
training/checkpoints/unified_final/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
training/checkpoints/unified_final/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "is_local": true,
8
+ "max_length": null,
9
+ "model_max_length": 2048,
10
+ "pad_to_multiple_of": null,
11
+ "pad_token": "</s>",
12
+ "pad_token_type_id": 0,
13
+ "padding_side": "left",
14
+ "sp_model_kwargs": {},
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "truncation_side": "left",
17
+ "unk_token": "<unk>",
18
+ "use_default_system_prompt": false
19
+ }
training/checkpoints/unified_final/unified_reward_log.json ADDED
@@ -0,0 +1,810 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "accuracy": [
3
+ 0.012478123821101302,
4
+ 0.013689774048328765,
5
+ 0.12357050236883002,
6
+ 0.043150096433237195,
7
+ 0.11808098944816375,
8
+ 0.14478551750907398,
9
+ 0.21936089415676943,
10
+ 0.14560732765872023,
11
+ 0.12766012796254073,
12
+ 0.16228250732999258,
13
+ 0.19256023689530533,
14
+ 0.153446869824083,
15
+ 0.08735395734236795,
16
+ 0.25620539761275585,
17
+ 0.2796424323605421,
18
+ 0.4050695781981913,
19
+ 0.34320680785281277,
20
+ 0.39042326634482405,
21
+ 0.24141882976569753,
22
+ 0.2882491476114424,
23
+ 0.2805112680700598,
24
+ 0.1299182187184869,
25
+ 0.18283964773559502,
26
+ 0.08174918994377885,
27
+ 0.1305077084983307,
28
+ 0.15188368799701088,
29
+ 0.10731278214010087,
30
+ 0.10817607256366782,
31
+ 0.1742403849902705,
32
+ 0.15966549523684162,
33
+ 0.21224383614993403,
34
+ 0.30634267989144903,
35
+ 0.2563189622014761,
36
+ 0.13088561721084532,
37
+ 0.23896305011421776,
38
+ 0.36338720554077614,
39
+ 0.2743395734578371,
40
+ 0.2785670698390685,
41
+ 0.26690704237418583,
42
+ 0.23420825800444123,
43
+ 0.4486492634482796,
44
+ 0.3085314377908274,
45
+ 0.27236165767163295,
46
+ 0.351135627192783,
47
+ 0.37157259147763155,
48
+ 0.4091061054548437,
49
+ 0.3321387716436809,
50
+ 0.25690332708634805,
51
+ 0.4042620632377111,
52
+ 0.21426805183517378,
53
+ 0.46486986328175767,
54
+ 0.5354255396266014,
55
+ 0.5316739152617584,
56
+ 0.3626249278251227,
57
+ 0.5560084815324287,
58
+ 0.47374602488847506,
59
+ 0.5622030981309204,
60
+ 0.6260334739834723,
61
+ 0.5388746766273916,
62
+ 0.43546972183358157,
63
+ 0.4384314355118149,
64
+ 0.43255371653260083,
65
+ 0.382003842773009,
66
+ 0.33916141995282467,
67
+ 0.4102824234143368,
68
+ 0.4002692943218704,
69
+ 0.4433627484561765,
70
+ 0.5707634448719365,
71
+ 0.3326736211199734,
72
+ 0.41868448313128437,
73
+ 0.4830820909726724,
74
+ 0.5073173724203757,
75
+ 0.6011403764343056,
76
+ 0.2652010267221505,
77
+ 0.5708498617899997,
78
+ 0.5372080254474398,
79
+ 0.34268688791221447,
80
+ 0.36077516272765764,
81
+ 0.6577040443039563,
82
+ 0.5249539674929385,
83
+ 0.3393068936409599,
84
+ 0.3981918416905377,
85
+ 0.5998766558760262,
86
+ 0.3886278953534839,
87
+ 0.47030574201103836,
88
+ 0.5933578772929455,
89
+ 0.629797753552287,
90
+ 0.6829957361516797,
91
+ 0.5975855789903534,
92
+ 0.37033629002672747,
93
+ 0.40129960235208273,
94
+ 0.44104763492941856,
95
+ 0.5250475457257945,
96
+ 0.5792574424612014,
97
+ 0.25491493314992414,
98
+ 0.4456432306425367,
99
+ 0.3674802188566988,
100
+ 0.5168529125349757,
101
+ 0.7135775878197881,
102
+ 0.408872426591652,
103
+ 0.29645813006976085,
104
+ 0.5807047440217663,
105
+ 0.3951396545427582,
106
+ 0.5820897600332913,
107
+ 0.5751887943251881,
108
+ 0.6462836385320105,
109
+ 0.452535930180199,
110
+ 0.6309295986678539,
111
+ 0.521345004487674,
112
+ 0.7523772581521466,
113
+ 0.3868275580258203,
114
+ 0.6621844534173644,
115
+ 0.757102247782526,
116
+ 0.7496667811480936,
117
+ 0.765902349873787,
118
+ 0.7620735178706088,
119
+ 0.8005386810387373,
120
+ 0.7600417191929723,
121
+ 0.7790964529097753,
122
+ 0.8060362095807505,
123
+ 0.6639245812548539,
124
+ 0.49642928937921477,
125
+ 0.4622820479255877,
126
+ 0.5039745619269863,
127
+ 0.5521504355740943,
128
+ 0.763103948879152,
129
+ 0.3649169562800698,
130
+ 0.8642640291197355,
131
+ 0.7673212948914258,
132
+ 0.6856467187291327,
133
+ 0.6203947744628628,
134
+ 0.635864180446877,
135
+ 0.7076110516058842,
136
+ 0.45257112707172986,
137
+ 0.4927382976084982,
138
+ 0.735338338570779,
139
+ 0.7325108773598185,
140
+ 0.5286115260781837,
141
+ 0.6873601944038981,
142
+ 0.7558585478414992,
143
+ 0.8025525164825894,
144
+ 0.5403924472630024,
145
+ 0.8109585656614495,
146
+ 0.45960476465808653,
147
+ 0.7726514123926349,
148
+ 0.78036072270019,
149
+ 0.5612159043391909,
150
+ 0.668619691132455,
151
+ 0.7187997825397312,
152
+ 0.6008389099901545,
153
+ 0.5160061409523324,
154
+ 0.6712722339255528,
155
+ 0.25213094055121654,
156
+ 0.7931299787283417,
157
+ 0.5770709363152806,
158
+ 0.3674653100689218,
159
+ 0.7533031922202384,
160
+ 0.5477579357220128,
161
+ 0.9013020257140825,
162
+ 0.774595058715597,
163
+ 0.5444791193214735,
164
+ 0.28536322558907645,
165
+ 0.8018009673613502,
166
+ 0.7534115956222964,
167
+ 0.8178817865612724,
168
+ 0.7691389758719754,
169
+ 0.746364161759599,
170
+ 0.7686015134039534,
171
+ 0.734219302571865,
172
+ 0.32221002464589255,
173
+ 0.47941368112339633,
174
+ 0.7168057798061833,
175
+ 0.772261652825011,
176
+ 0.5291935548529084,
177
+ 0.7485607594114032,
178
+ 0.5932522241567504,
179
+ 0.5648661194163807,
180
+ 0.5709367030781823,
181
+ 0.7752278802176389,
182
+ 0.6248770881515031,
183
+ 0.5446761697530746,
184
+ 0.8044651419608864,
185
+ 0.855248827897706,
186
+ 0.5436122580157401,
187
+ 0.9085174062877894,
188
+ 0.31500336882736524,
189
+ 0.6913784691774245,
190
+ 0.5400797382818436,
191
+ 0.6050753133365693,
192
+ 0.7986505120673587,
193
+ 0.8202528873914283,
194
+ 0.6996518377501237,
195
+ 0.8313200483947909,
196
+ 0.4808844911385792,
197
+ 0.7306097140061414,
198
+ 0.5058602896511918,
199
+ 0.6438089653119033,
200
+ 0.7879260241436392,
201
+ 0.8337068369817564,
202
+ 0.537435884385747
203
+ ],
204
+ "outcome": [
205
+ 0.4,
206
+ 0.42500000000000004,
207
+ 0.4375,
208
+ 0.42500000000000004,
209
+ 0.4,
210
+ 0.4,
211
+ 0.4,
212
+ 0.25,
213
+ 0.4,
214
+ 0.0,
215
+ 0.0,
216
+ 0.0,
217
+ 0.0,
218
+ 0.07500000000000001,
219
+ 0.025,
220
+ 0.07500000000000001,
221
+ 0.0,
222
+ 0.07500000000000001,
223
+ 0.05,
224
+ 0.07500000000000001,
225
+ 0.225,
226
+ 0.4,
227
+ 0.4,
228
+ 0.4,
229
+ 0.42500000000000004,
230
+ 0.4,
231
+ 0.4,
232
+ 0.4,
233
+ 0.4,
234
+ 0.4,
235
+ 0.35000000000000003,
236
+ 0.175,
237
+ 0.15,
238
+ 0.15000000000000002,
239
+ 0.07500000000000001,
240
+ 0.17500000000000002,
241
+ 0.1,
242
+ 0.0,
243
+ 0.05,
244
+ 0.07500000000000001,
245
+ 0.07500000000000001,
246
+ 0.07500000000000001,
247
+ 0.025,
248
+ 0.0,
249
+ 0.0,
250
+ 0.0,
251
+ 0.07500000000000001,
252
+ 0.15000000000000002,
253
+ 0.0,
254
+ 0.05,
255
+ 0.0,
256
+ 0.025,
257
+ 0.0,
258
+ 0.0,
259
+ 0.0,
260
+ 0.05,
261
+ 0.0,
262
+ 0.05,
263
+ 0.025,
264
+ 0.07500000000000001,
265
+ 0.0,
266
+ 0.05,
267
+ 0.025,
268
+ 0.1,
269
+ 0.025,
270
+ 0.025,
271
+ 0.025,
272
+ 0.025,
273
+ 0.0,
274
+ 0.05,
275
+ 0.05,
276
+ 0.0,
277
+ 0.05,
278
+ 0.0,
279
+ 0.0,
280
+ 0.025,
281
+ 0.05,
282
+ 0.025,
283
+ 0.0,
284
+ 0.025,
285
+ 0.05,
286
+ 0.07500000000000001,
287
+ 0.125,
288
+ 0.25,
289
+ 0.125,
290
+ 0.2,
291
+ 0.05,
292
+ 0.17500000000000002,
293
+ 0.225,
294
+ 0.2,
295
+ 0.30000000000000004,
296
+ 0.375,
297
+ 0.35,
298
+ 0.42500000000000004,
299
+ 0.35000000000000003,
300
+ 0.42500000000000004,
301
+ 0.4,
302
+ 0.4,
303
+ 0.4,
304
+ 0.42500000000000004,
305
+ 0.42500000000000004,
306
+ 0.45,
307
+ 0.4,
308
+ 0.4,
309
+ 0.4,
310
+ 0.4,
311
+ 0.4,
312
+ 0.4,
313
+ 0.45,
314
+ 0.35000000000000003,
315
+ 0.4,
316
+ 0.4,
317
+ 0.4,
318
+ 0.35000000000000003,
319
+ 0.4,
320
+ 0.4,
321
+ 0.25,
322
+ 0.25,
323
+ 0.35000000000000003,
324
+ 0.4,
325
+ 0.35000000000000003,
326
+ 0.30000000000000004,
327
+ 0.4,
328
+ 0.35000000000000003,
329
+ 0.35000000000000003,
330
+ 0.35000000000000003,
331
+ 0.4,
332
+ 0.35000000000000003,
333
+ 0.35000000000000003,
334
+ 0.2,
335
+ 0.35000000000000003,
336
+ 0.4,
337
+ 0.35000000000000003,
338
+ 0.42500000000000004,
339
+ 0.4,
340
+ 0.30000000000000004,
341
+ 0.4,
342
+ 0.4,
343
+ 0.42500000000000004,
344
+ 0.42500000000000004,
345
+ 0.4,
346
+ 0.42500000000000004,
347
+ 0.4,
348
+ 0.4,
349
+ 0.35000000000000003,
350
+ 0.42500000000000004,
351
+ 0.30000000000000004,
352
+ 0.42500000000000004,
353
+ 0.4,
354
+ 0.4,
355
+ 0.4,
356
+ 0.42500000000000004,
357
+ 0.4,
358
+ 0.35000000000000003,
359
+ 0.4,
360
+ 0.42500000000000004,
361
+ 0.4,
362
+ 0.42500000000000004,
363
+ 0.25,
364
+ 0.35000000000000003,
365
+ 0.4,
366
+ 0.4,
367
+ 0.35000000000000003,
368
+ 0.4,
369
+ 0.4,
370
+ 0.35000000000000003,
371
+ 0.4,
372
+ 0.4,
373
+ 0.4,
374
+ 0.4,
375
+ 0.4,
376
+ 0.4,
377
+ 0.4,
378
+ 0.42500000000000004,
379
+ 0.4,
380
+ 0.4,
381
+ 0.4,
382
+ 0.375,
383
+ 0.4,
384
+ 0.375,
385
+ 0.4,
386
+ 0.35000000000000003,
387
+ 0.4,
388
+ 0.4,
389
+ 0.35000000000000003,
390
+ 0.42500000000000004,
391
+ 0.4,
392
+ 0.4,
393
+ 0.42500000000000004,
394
+ 0.4,
395
+ 0.4,
396
+ 0.4,
397
+ 0.4,
398
+ 0.45,
399
+ 0.4,
400
+ 0.4,
401
+ 0.4,
402
+ 0.35000000000000003,
403
+ 0.4,
404
+ 0.4
405
+ ],
406
+ "bluff": [
407
+ -0.5,
408
+ -0.5,
409
+ -0.5,
410
+ -0.5,
411
+ -0.5,
412
+ -0.5,
413
+ -0.5,
414
+ -0.5,
415
+ -0.5,
416
+ -0.5,
417
+ -0.5,
418
+ -0.5,
419
+ -0.5,
420
+ -0.5,
421
+ -0.5,
422
+ -0.5,
423
+ -0.5,
424
+ -0.5,
425
+ -0.5,
426
+ -0.5,
427
+ -0.5,
428
+ -0.5,
429
+ -0.5,
430
+ -0.5,
431
+ -0.5,
432
+ -0.5,
433
+ -0.5,
434
+ -0.5,
435
+ -0.5,
436
+ -0.5,
437
+ -0.5,
438
+ -0.5,
439
+ -0.5,
440
+ -0.5,
441
+ -0.5,
442
+ -0.5,
443
+ -0.5,
444
+ -0.5,
445
+ -0.5,
446
+ -0.5,
447
+ -0.5,
448
+ -0.5,
449
+ -0.5,
450
+ -0.5,
451
+ -0.5,
452
+ -0.5,
453
+ -0.5,
454
+ -0.5,
455
+ -0.5,
456
+ -0.5,
457
+ -0.5,
458
+ -0.5,
459
+ -0.5,
460
+ -0.5,
461
+ -0.5,
462
+ -0.5,
463
+ -0.5,
464
+ -0.5,
465
+ -0.5,
466
+ -0.5,
467
+ -0.5,
468
+ -0.5,
469
+ -0.5,
470
+ -0.5,
471
+ -0.5,
472
+ -0.5,
473
+ -0.5,
474
+ -0.5,
475
+ -0.5,
476
+ -0.5,
477
+ -0.5,
478
+ -0.5,
479
+ -0.5,
480
+ -0.5,
481
+ -0.5,
482
+ -0.5,
483
+ -0.5,
484
+ -0.5,
485
+ -0.5,
486
+ -0.5,
487
+ -0.5,
488
+ -0.5,
489
+ -0.5,
490
+ -0.5,
491
+ -0.5,
492
+ -0.5,
493
+ -0.5,
494
+ -0.5,
495
+ -0.5,
496
+ -0.5,
497
+ -0.5,
498
+ -0.5,
499
+ -0.5,
500
+ -0.5,
501
+ -0.5,
502
+ -0.5,
503
+ -0.5,
504
+ -0.5,
505
+ -0.5,
506
+ -0.5,
507
+ -0.5,
508
+ -0.5,
509
+ -0.5,
510
+ -0.5,
511
+ -0.5,
512
+ -0.5,
513
+ -0.5,
514
+ -0.5,
515
+ -0.5,
516
+ -0.5,
517
+ -0.5,
518
+ -0.5,
519
+ -0.5,
520
+ -0.5,
521
+ -0.5,
522
+ -0.5,
523
+ -0.5,
524
+ -0.5,
525
+ -0.5,
526
+ -0.5,
527
+ -0.5,
528
+ -0.5,
529
+ -0.5,
530
+ -0.5,
531
+ -0.5,
532
+ -0.5,
533
+ -0.5,
534
+ -0.5,
535
+ -0.5,
536
+ -0.5,
537
+ -0.5,
538
+ -0.5,
539
+ -0.5,
540
+ -0.5,
541
+ -0.5,
542
+ -0.5,
543
+ -0.5,
544
+ -0.5,
545
+ -0.5,
546
+ -0.5,
547
+ -0.5,
548
+ -0.5,
549
+ -0.5,
550
+ -0.5,
551
+ -0.5,
552
+ -0.5,
553
+ -0.5,
554
+ -0.5,
555
+ -0.5,
556
+ -0.5,
557
+ -0.5,
558
+ -0.5,
559
+ -0.5,
560
+ -0.5,
561
+ -0.5,
562
+ -0.5,
563
+ -0.5,
564
+ -0.5,
565
+ -0.5,
566
+ -0.5,
567
+ -0.5,
568
+ -0.5,
569
+ -0.5,
570
+ -0.5,
571
+ -0.5,
572
+ -0.5,
573
+ -0.5,
574
+ -0.5,
575
+ -0.5,
576
+ -0.5,
577
+ -0.5,
578
+ -0.5,
579
+ -0.5,
580
+ -0.5,
581
+ -0.5,
582
+ -0.5,
583
+ -0.5,
584
+ -0.5,
585
+ -0.5,
586
+ -0.5,
587
+ -0.5,
588
+ -0.5,
589
+ -0.5,
590
+ -0.5,
591
+ -0.5,
592
+ -0.5,
593
+ -0.5,
594
+ -0.5,
595
+ -0.5,
596
+ -0.5,
597
+ -0.5,
598
+ -0.5,
599
+ -0.5,
600
+ -0.5,
601
+ -0.5,
602
+ -0.5,
603
+ -0.5,
604
+ -0.5,
605
+ -0.5,
606
+ -0.5
607
+ ],
608
+ "total": [
609
+ -0.005632656662614553,
610
+ 0.0035414209169150612,
611
+ 0.0463746758290905,
612
+ 0.01385253375163301,
613
+ 0.031328346306857296,
614
+ 0.040674931128175884,
615
+ 0.06677631295486929,
616
+ -0.011537435319447932,
617
+ 0.03468104478688924,
618
+ -0.0932011224345026,
619
+ -0.08260391708664314,
620
+ -0.09629359556157094,
621
+ -0.11942611493017122,
622
+ -0.03407811083553544,
623
+ -0.04337514867381029,
624
+ 0.018024352369366947,
625
+ -0.02987761725151553,
626
+ 0.012898143220688411,
627
+ -0.04800340958200586,
628
+ -0.022862798335995183,
629
+ 0.026928943824520928,
630
+ 0.03547137655147041,
631
+ 0.05399387670745824,
632
+ 0.018612216480322585,
633
+ 0.044427697974415745,
634
+ 0.043159290798953795,
635
+ 0.027559473749035293,
636
+ 0.02786162539728372,
637
+ 0.05098413474659466,
638
+ 0.045882923332894544,
639
+ 0.04678534265247689,
640
+ 0.018469937962007153,
641
+ -0.007788363229483373,
642
+ -0.05169003397620414,
643
+ -0.04011293246002378,
644
+ 0.03843552193927165,
645
+ -0.018981149289757013,
646
+ -0.05250152555632605,
647
+ -0.039082535169034954,
648
+ -0.04177710969844557,
649
+ 0.033277242206897865,
650
+ -0.015763996773210408,
651
+ -0.045923419814928465,
652
+ -0.02710253048252593,
653
+ -0.019949592982828956,
654
+ -0.006812863090804698,
655
+ -0.007501429924711707,
656
+ -0.007583835519778186,
657
+ -0.008508277866801141,
658
+ -0.05750618185768919,
659
+ 0.012704452148615191,
660
+ 0.0461489388693105,
661
+ 0.036085870341615436,
662
+ -0.023081275261207068,
663
+ 0.04460296853635004,
664
+ 0.03331110871096628,
665
+ 0.04677108434582211,
666
+ 0.0866117158942153,
667
+ 0.04735613681958707,
668
+ 0.02866440264175356,
669
+ 0.0034510024291352186,
670
+ 0.01889380078641028,
671
+ -0.00754865502944687,
672
+ 0.0037064969834886344,
673
+ 0.0023488481950178913,
674
+ -0.001155746987345354,
675
+ 0.013926961959661782,
676
+ 0.058517205705177766,
677
+ -0.03356423260800931,
678
+ 0.014039569095949535,
679
+ 0.03657873184043532,
680
+ 0.02756108034713149,
681
+ 0.07789913175200697,
682
+ -0.05717964064724733,
683
+ 0.04979745162649989,
684
+ 0.04677280890660393,
685
+ -0.012559589230724939,
686
+ -0.014978693045319853,
687
+ 0.08019641550638473,
688
+ 0.04248388862252848,
689
+ -0.01374258722566403,
690
+ 0.015617144591688177,
691
+ 0.10370682955660918,
692
+ 0.07351976337371936,
693
+ 0.05835700970386343,
694
+ 0.12767525705253094,
695
+ 0.08792921374330046,
696
+ 0.1502985076530879,
697
+ 0.13790495264662364,
698
+ 0.049617701509354614,
699
+ 0.09545486082322892,
700
+ 0.13561667222529647,
701
+ 0.15626664100402804,
702
+ 0.2014901048614205,
703
+ 0.06172022660247342,
704
+ 0.15472513072488783,
705
+ 0.11861807659984457,
706
+ 0.1708985193872415,
707
+ 0.23975215573692582,
708
+ 0.1418553493070782,
709
+ 0.10251034552441629,
710
+ 0.21074666040761822,
711
+ 0.12829887908996535,
712
+ 0.19373141601165192,
713
+ 0.19131607801381584,
714
+ 0.21619927348620369,
715
+ 0.1483875755630696,
716
+ 0.2108253595337488,
717
+ 0.18997075157068588,
718
+ 0.23583204035325128,
719
+ 0.12538964530903712,
720
+ 0.22176455869607747,
721
+ 0.25498578672388406,
722
+ 0.2348833734018327,
723
+ 0.25806582245582543,
724
+ 0.256725731254713,
725
+ 0.217688538363558,
726
+ 0.20351460171754027,
727
+ 0.24518375851842128,
728
+ 0.2721126733532626,
729
+ 0.2048736034391988,
730
+ 0.12875025128272513,
731
+ 0.15179871677395568,
732
+ 0.14889109667444517,
733
+ 0.16575265245093296,
734
+ 0.23958638210770317,
735
+ 0.11772093469802442,
736
+ 0.27499241019190734,
737
+ 0.24106245321199898,
738
+ 0.15997635155519643,
739
+ 0.18963817106200198,
740
+ 0.21255246315640697,
741
+ 0.22016386806205945,
742
+ 0.1571498944751054,
743
+ 0.16245840416297436,
744
+ 0.21236841849977267,
745
+ 0.24637880707593643,
746
+ 0.17501403412736427,
747
+ 0.23932606804136433,
748
+ 0.2633004917445247,
749
+ 0.27089338076890623,
750
+ 0.1878873565420508,
751
+ 0.2738354979815073,
752
+ 0.15086166763033024,
753
+ 0.24292799433742218,
754
+ 0.27187625294506645,
755
+ 0.1514255665187168,
756
+ 0.2327668918963592,
757
+ 0.24157992388890587,
758
+ 0.20029361849655403,
759
+ 0.1706021493333163,
760
+ 0.23369528187394348,
761
+ 0.07824582919292578,
762
+ 0.25009549255491953,
763
+ 0.19197482771034816,
764
+ 0.1273628585241226,
765
+ 0.25365611727708337,
766
+ 0.19046527750270448,
767
+ 0.25295570899992886,
768
+ 0.24360827055045886,
769
+ 0.1805676917625157,
770
+ 0.08987712895617675,
771
+ 0.25313033857647255,
772
+ 0.25369405846780374,
773
+ 0.2762586252964453,
774
+ 0.24169864155519138,
775
+ 0.2512274566158596,
776
+ 0.25901052969138366,
777
+ 0.24697675590015272,
778
+ 0.10277350862606237,
779
+ 0.1577947883931887,
780
+ 0.2408820229321641,
781
+ 0.2602915784887538,
782
+ 0.1839677441985179,
783
+ 0.2519962657939911,
784
+ 0.19763827845486265,
785
+ 0.18770314179573322,
786
+ 0.1810778460773638,
787
+ 0.26132975807617365,
788
+ 0.1999569808530261,
789
+ 0.1806366594135761,
790
+ 0.2540627996863101,
791
+ 0.28933708976419703,
792
+ 0.18026429030550906,
793
+ 0.2904810922007262,
794
+ 0.10900117908957782,
795
+ 0.2319824642120985,
796
+ 0.17902790839864524,
797
+ 0.2105263596677992,
798
+ 0.26952767922357546,
799
+ 0.27708851058699985,
800
+ 0.23487814321254327,
801
+ 0.2809620169381768,
802
+ 0.1758095718985027,
803
+ 0.2457133999021494,
804
+ 0.1670511013779171,
805
+ 0.21533313785916613,
806
+ 0.2482741084502737,
807
+ 0.2817973929436147,
808
+ 0.1781025595350114
809
+ ]
810
+ }
training/parse_poker.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parse IRC poker pdb files and produce labeled bluff examples.
3
+
4
+ Line format: player_name timestamp num_players position preflop flop turn river bankroll won won2 [cards]
5
+ Action codes: f=fold, c=call, r=raise, b=bet, k=check, B=blind, -=no action
6
+ Cards at end of line = player went to showdown.
7
+
8
+ BLUFF = True: preflop has 'r' or 'b', hand ends in fold (last non-dash action ends in 'f'), no cards at end.
9
+ BLUFF = False: cards at end (showdown) OR folded with no aggression.
10
+ """
11
+
12
+ import json
13
+ import os
14
+ import re
15
+ from pathlib import Path
16
+
17
# Dataset locations and parsing limits.
BASE_POKER = Path(__file__).resolve().parent / "data" / "poker"
PDB_DIR = BASE_POKER / "IRCdata" / "holdem" / "199901" / "pdb"
OUT_PATH = BASE_POKER / "bluff_labels.json"
MAX_EXAMPLES = 50_000

# A card token is one rank character (2-9, T, J, Q, K, A) followed by one
# suit character (c/d/h/s), case-insensitive.
CARD_PATTERN = re.compile(r"^[2-9TJKQA][cdhs]$", re.IGNORECASE)


def _is_card_token(s: str) -> bool:
    """Return True if *s* (after stripping) looks like a single card token, e.g. 'Ks'."""
    if not s:
        return False
    return CARD_PATTERN.match(s.strip()) is not None
27
+
28
+
29
def _has_cards_at_end(tokens: list) -> bool:
    """True if the hand line ends with hole-card tokens, i.e. the player showed down."""
    tail = tokens[11:]
    if not tail:
        # Nothing after the 11 fixed fields -> no showdown cards.
        return False
    # Every trailing token must parse as a card (e.g. "Ks Kh").
    return all(_is_card_token(tok) for tok in tail)
36
+
37
+
38
def _last_non_dash_ends_in_f(preflop: str, flop: str, turn: str, river: str) -> bool:
    """Whether the player's final recorded action was a fold.

    Streets are scanned from river back to preflop; the first street with a
    real action string ('-' means no action) decides the result.
    """
    for street in (river, turn, flop, preflop):
        if not street or street == "-":
            continue
        return street.strip().endswith("f")
    return False
44
+
45
+
46
def _preflop_aggressive(preflop: str) -> bool:
    """True if the preflop action string contains a raise ('r') or bet ('b')."""
    actions = preflop or ""
    return any(code in actions for code in "rb")
49
+
50
+
51
def parse_line(line: str) -> dict | None:
    """
    Parse one IRC pdb hand line into a labeled bluff example.

    Returns {"text": str, "is_bluff": bool}, or None when the line is blank
    or has fewer than the 11 fixed fields.
    """
    line = line.strip()
    if not line:
        return None
    tokens = line.split()
    if len(tokens) < 11:
        return None
    # Fixed fields: name ts num_players position preflop flop turn river bankroll won won2.
    # Only the fields used for labeling / text are extracted.
    num_players = tokens[2]
    position = tokens[3]
    preflop, flop, turn, river = tokens[4:8]
    won = tokens[9]
    try:
        pot = abs(int(won))
    except ValueError:
        # Malformed amount: fall back to 0 rather than dropping the hand.
        pot = 0

    has_cards = _has_cards_at_end(tokens)
    ends_in_fold = _last_non_dash_ends_in_f(preflop, flop, turn, river)
    aggressive = _preflop_aggressive(preflop)

    # BLUFF = True: aggressive preflop, hand ended in a fold, no showdown.
    # The same conjunction already labels showdowns (has_cards) and passive
    # folds (not aggressive) as not-bluff, so no extra overrides are needed.
    is_bluff = aggressive and ends_in_fold and not has_cards

    text = (
        f"Position {position} of {num_players}. "
        f"Preflop: {preflop}. Flop: {flop}. Turn: {turn}. River: {river}. Pot: {pot}."
    )
    return {"text": text, "is_bluff": is_bluff}
94
+
95
+
96
def main():
    """Scan pdb files, label each hand line, and write bluff_labels.json."""
    os.makedirs(OUT_PATH.parent, exist_ok=True)

    # IRC data files are named pdb.<player> (e.g. pdb.A2k), not *.pdb.
    def _pdb_files_in(directory):
        return [p for p in directory.iterdir() if p.is_file() and p.name.startswith("pdb.")]

    pdb_files = _pdb_files_in(PDB_DIR) if PDB_DIR.exists() else []
    if not pdb_files:
        # Fall back to the first pdb/ directory with files anywhere under the data root.
        for candidate in BASE_POKER.rglob("pdb"):
            if candidate.is_dir():
                pdb_files.extend(_pdb_files_in(candidate))
                if pdb_files:
                    break
    if not pdb_files:
        print(f"ERROR: No pdb files in {PDB_DIR} or under {BASE_POKER}")
        return

    examples = []
    for pdb_path in pdb_files:
        if len(examples) >= MAX_EXAMPLES:
            break
        try:
            with open(pdb_path, "r", encoding="utf-8", errors="replace") as handle:
                for raw_line in handle:
                    if len(examples) >= MAX_EXAMPLES:
                        break
                    example = parse_line(raw_line)
                    if example is not None:
                        examples.append(example)
        except Exception as e:
            # Best-effort: a single unreadable file should not kill the run.
            print(f"Warning: {pdb_path}: {e}")

    with open(OUT_PATH, "w") as out:
        json.dump(examples, out, indent=0)

    total = len(examples)
    bluffs = sum(1 for example in examples if example["is_bluff"])
    print(f"Total examples: {total}")
    print(f"Class balance: is_bluff=True {bluffs}, is_bluff=False {total - bluffs}")
    print(f"Saved to {OUT_PATH}")
133
+
134
+
135
# Script entry point: parse the IRC data and write the labeled dataset.
if __name__ == "__main__":
    main()
training/plot_phase2.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Plot the Phase 2 GRPO mean-reward curve from a saved trainer_state.json."""
import json

import matplotlib.pyplot as plt

STATE_PATH = "training/checkpoints/phase2_final/checkpoint-200/trainer_state.json"

with open(STATE_PATH) as f:
    state = json.load(f)

# Not every log_history entry carries a "reward" key (e.g. the final
# train-summary entry only has runtime stats); keep only the step logs that do
# so a missing key cannot raise KeyError.
reward_logs = [e for e in state["log_history"] if "reward" in e]
steps = [e["step"] for e in reward_logs]
rewards = [e["reward"] for e in reward_logs]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(steps, rewards, color="#4C72B0", linewidth=2.5, marker="o", markersize=4)
# Dashed reference lines make the start-to-end improvement visible at a glance.
ax.axhline(y=rewards[0], color="gray", linestyle="--", alpha=0.5, label=f"Start: {rewards[0]:.3f}")
ax.axhline(y=rewards[-1], color="#2ca02c", linestyle="--", alpha=0.5, label=f"End: {rewards[-1]:.3f}")
ax.fill_between(steps, rewards, rewards[0], alpha=0.1, color="#4C72B0")
ax.set_xlabel("Training Step", fontsize=13)
ax.set_ylabel("Mean Reward", fontsize=13)
ax.set_title("ArbitrAgent Phase 2 GRPO Training\nContractor Curriculum (Human Imitation)", fontsize=14)
ax.legend(fontsize=11)
ax.set_ylim(0, 0.5)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("training/phase2_reward_curve.png", dpi=150)
print(f"Saved. Reward: {rewards[0]:.3f} → {rewards[-1]:.3f} over {steps[-1]} steps")
training/train_bluff_classifier.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train DistilBERT binary classifier on IRC poker bluff labels.
3
+
4
+ Data: training/data/poker/bluff_labels.json
5
+ Model: distilbert-base-uncased + linear 768→2
6
+ 80/20 train/val stratified, 3 epochs, lr 2e-5, batch 32
7
+ Saves: training/checkpoints/bluff_classifier.pt, bluff_classifier_tokenizer/
8
+ """
9
+
10
+ import json
11
+ import os
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from sklearn.model_selection import train_test_split
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from transformers import AutoTokenizer, AutoModel
19
+
20
# Filesystem layout and training hyperparameters.
SCRIPT_DIR = Path(__file__).resolve().parent
DATA_PATH = SCRIPT_DIR / "data" / "poker" / "bluff_labels.json"  # output of parse_poker.py
CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
MODEL_PT = CHECKPOINT_DIR / "bluff_classifier.pt"  # saved state_dict of BluffClassifier
TOKENIZER_DIR = CHECKPOINT_DIR / "bluff_classifier_tokenizer"
MAX_LENGTH = 128  # tokenizer truncation/padding length
EPOCHS = 3
LR = 2e-5
BATCH_SIZE = 32
29
+
30
+
31
class BluffClassifier(nn.Module):
    """DistilBERT encoder with a linear head mapping hidden size (768) to 2 logits
    (index 0 = not_bluff, index 1 = bluff)."""

    def __init__(self, base_model: str = "distilbert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model)
        self.head = nn.Linear(self.encoder.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Classify from the first-token ([CLS]) hidden state; extra kwargs
        # (e.g. "labels" from a DataLoader batch) are accepted and ignored.
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_state = encoded.last_hidden_state[:, 0, :]
        return self.head(cls_state)
44
+
45
+
46
class BluffDataset(Dataset):
    """Tokenize-on-access dataset of (hand-summary text, 0/1 bluff label) pairs."""

    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
            return_tensors="pt",
        )
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        # The tokenizer returns (1, L) tensors; drop the leading batch dim so
        # the default DataLoader collation can stack items.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": label,
        }
68
+
69
+
70
def main() -> None:
    """Train the bluff classifier end to end and save model + tokenizer.

    Loads labeled examples from DATA_PATH, makes a stratified 80/20
    train/val split, fine-tunes DistilBERT for EPOCHS epochs, reports
    validation accuracy and F1 after each epoch, then writes the model
    state_dict to MODEL_PT and the tokenizer to TOKENIZER_DIR.
    """
    if not DATA_PATH.exists():
        print(f"ERROR: {DATA_PATH} not found. Run training/parse_poker.py first.")
        return
    with open(DATA_PATH) as f:
        data = json.load(f)
    texts = [x["text"] for x in data]
    labels = [1 if x["is_bluff"] else 0 for x in data]  # 1 = bluff, 0 = not bluff

    # Stratified split keeps the bluff/not-bluff ratio identical in both sets.
    X_train, X_val, y_train, y_val = train_test_split(
        texts, labels, test_size=0.2, stratify=labels, random_state=42
    )

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    train_ds = BluffDataset(X_train, y_train, tokenizer)
    val_ds = BluffDataset(X_val, y_val, tokenizer)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BluffClassifier().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    for epoch in range(EPOCHS):
        # --- one training pass over the train split ---
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            out = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            loss = criterion(out, batch["labels"].to(device))
            loss.backward()
            opt.step()

        # --- validation: accuracy plus hand-rolled binary F1 (bluff = 1) ---
        model.eval()
        correct, total = 0, 0
        all_pred, all_true = [], []
        with torch.no_grad():
            for batch in val_loader:
                out = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                )
                pred = out.argmax(dim=1)
                correct += (pred == batch["labels"].to(device)).sum().item()
                total += pred.size(0)
                all_pred.extend(pred.cpu().tolist())
                all_true.extend(batch["labels"].tolist())
        acc = correct / total if total else 0

        # F1 binary: bluff=1
        tp = sum(1 for p, t in zip(all_pred, all_true) if p == 1 and t == 1)
        fp = sum(1 for p, t in zip(all_pred, all_true) if p == 1 and t == 0)
        fn = sum(1 for p, t in zip(all_pred, all_true) if p == 0 and t == 1)
        prec = tp / (tp + fp) if (tp + fp) else 0
        rec = tp / (tp + fn) if (tp + fn) else 0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0

        print(f"Epoch {epoch + 1}/{EPOCHS} Val accuracy: {acc:.4f} Val F1: {f1:.4f}")

    # Warn (but still save) when the final epoch's validation accuracy misses target.
    if acc < 0.65:
        print(f"WARNING: Val accuracy {acc:.4f} < 0.65 (target). Consider more data or epochs.")
    torch.save(model.state_dict(), MODEL_PT)
    tokenizer.save_pretrained(TOKENIZER_DIR)
    print(f"Saved model to {MODEL_PT}, tokenizer to {TOKENIZER_DIR}")
139
+
140
+
141
# Script entry point: train and save the bluff classifier.
if __name__ == "__main__":
    main()