Addyk24 committed on
Commit
ffeaf35
·
1 Parent(s): 4d0182e

Added RL environment training script using Unsloth and HF GPU cloud

Files changed (1)
  1. grpo_train.py +644 -0
grpo_train.py ADDED
@@ -0,0 +1,644 @@
+ """
+ grpo_train.py — State-Based GRPO for Project Polymath
+ ======================================================
+ Trains an LLM to negotiate with expert stakeholders using proper
+ Group Relative Policy Optimization with weight updates.
+
+ THE KEY INSIGHT (State-Based GRPO):
+ TRL's GRPOTrainer is single-turn. Our environment is multi-turn.
+ Solution: treat every (state, next_action) pair as its own training prompt.
+ The model learns: "given THIS game state, what is the best next action?"
+
+ Instead of rolling out full episodes, we:
+ 1. Build a dataset of negotiation states (from oracle + your JSON topics)
+ 2. For each state, sample G=8 completions from the model
+ 3. Run each completion through the environment for ONE step
+ 4. Use GRPO advantage to update weights toward better single-step decisions
+ 5. Repeat across all states — the model learns the full strategy implicitly
+
+ USAGE:
+     # Pre-hackathon: verify the pipeline (no GPU needed)
+     python grpo_train.py --dry-run --states 10
+
+     # On-site Day 1 with HF GPU credits (the real run):
+     python grpo_train.py --use-unsloth --epochs 3 --states 50
+
+     # Without Unsloth (slower but works):
+     python grpo_train.py --model Qwen/Qwen2.5-1.5B-Instruct --epochs 3
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import re
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # ── Deps ───────────────────────────────────────────────────────────────────────
+ try:
+     import matplotlib
+     matplotlib.use("Agg")  # non-interactive backend for servers
+     import matplotlib.pyplot as plt
+     HAS_PLT = True
+ except ImportError:
+     HAS_PLT = False
+
+ try:
+     from unsloth import FastLanguageModel
+     HAS_UNSLOTH = True
+ except ImportError:
+     HAS_UNSLOTH = False
+
+ try:
+     from trl import GRPOConfig, GRPOTrainer
+     HAS_TRL = True
+ except Exception:  # broad on purpose: trl can fail in CPU-only dry-run setups
+     HAS_TRL = False
+
+ try:
+     from datasets import Dataset
+     HAS_DATASETS = True
+ except ImportError:
+     HAS_DATASETS = False
+
+ try:
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     HAS_TRANSFORMERS = True
+ except ImportError:
+     HAS_TRANSFORMERS = False
+
+ # ── Local imports ──────────────────────────────────────────────────────────────
+ from envs.environment import WorkSpaceEnvironment
+ from models.schemas import WorkSpaceAction
+
+ # ── Constants ──────────────────────────────────────────────────────────────────
+ TOPICS_FILE = Path("ai_pm_prompts.json")
+ OUTPUT_DIR = Path("artifacts/grpo_state_based")
+
+ # The three hidden constraints — static for easy/medium mode
+ HIDDEN_CONSTRAINTS = {
+     "Finance": "Budget must not exceed $50k.",
+     "Security": "Must include biometric 2FA.",
+     "UX": "Checkout must be a single click.",
+ }
+
+ # ── Action templates the model should learn to produce ─────────────────────────
+ ORACLE_ACTIONS = {
+     "ask_finance": json.dumps({
+         "action_type": "message_expert", "target": "Finance",
+         "content": "What is the hard budget ceiling the PRD must respect for launch?"
+     }),
+     "ask_security": json.dumps({
+         "action_type": "message_expert", "target": "Security",
+         "content": "What authentication controls must the PRD include? Is biometric 2FA required?"
+     }),
+     "ask_ux": json.dumps({
+         "action_type": "message_expert", "target": "UX",
+         "content": "What checkout experience is required? Should we target a single-click flow?"
+     }),
+     "propose_draft": json.dumps({
+         "action_type": "propose_draft", "target": "All",
+         "content": (
+             "PRD Draft:\n"
+             "1. Budget: Launch scope capped at $50k.\n"
+             "2. Security: Biometric 2FA required for login and sensitive actions.\n"
+             "3. UX: Single-click checkout flow."
+         ),
+     }),
+     "submit_final": json.dumps({
+         "action_type": "submit_final", "target": None,
+         "content": (
+             "Final PRD:\n"
+             "1. Budget cap: All launch costs must stay at or below $50k.\n"
+             "2. Security: The app must enforce biometric 2FA for all authentication.\n"
+             "3. UX: Checkout must be implemented as a true single-click experience."
+         ),
+     }),
+ }
+
+
+ # ── Utilities ──────────────────────────────────────────────────────────────────
+
+ def load_topics(limit: int = 50) -> list[str]:
+     if TOPICS_FILE.exists():
+         with TOPICS_FILE.open() as f:
+             return json.load(f)[:limit]
+     return [
+         "Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
+         "Build an AI-driven healthcare platform for enterprise customers.",
+         "Create a SaaS analytics tool for regulatory-heavy industries.",
+         "Design a gaming platform for Gen Z users with real-time features.",
+         "Develop a cross-platform product for low-bandwidth regions.",
+     ]
+
+
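+ # ai_pm_prompts.json is assumed to be a flat JSON array of topic strings, e.g.:
+ #   ["Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
+ #    "Build an AI-driven healthcare platform for enterprise customers."]
+
+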
+ def parse_action(text: str) -> Optional[WorkSpaceAction]:
+     """Parse a JSON action from model output. Returns None on failure."""
+     try:
+         match = re.search(r'\{[^{}]*"action_type"[^{}]*\}', text, re.DOTALL)
+         if not match:
+             return None
+         return WorkSpaceAction(**json.loads(match.group(0)))
+     except Exception:
+         return None
+
+
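+ # Example (illustrative): given a completion like
+ #   'Sure! {"action_type": "message_expert", "target": "Finance", "content": "Budget?"}'
+ # parse_action extracts the first JSON object containing "action_type" and
+ # validates it as a WorkSpaceAction; anything else (no JSON, nested braces,
+ # schema mismatch) yields None, which the reward function penalizes at -0.5.
+
+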
+ def format_discovered(env: WorkSpaceEnvironment) -> str:
+     lines = []
+     for name, expert in env.state().experts.items():
+         status = "✓ DISCOVERED" if expert.constraint_discovered_by_agent else "? unknown"
+         lines.append(f" {name}: {status}")
+     return "\n".join(lines)
+
+
+ # ── State-Based Prompt Builder ─────────────────────────────────────────────────
+
+ AGENT_SYSTEM_PROMPT = """You are an expert AI Project Manager in a multi-stakeholder negotiation.
+
+ TASK: Produce a final PRD that satisfies ALL three experts — Finance, Security, and UX.
+ Each expert holds a hidden constraint you must discover through targeted questions.
+
+ STRATEGY:
+ 1. Message each expert INDIVIDUALLY (not "All") to discover their constraint.
+ 2. Once all constraints are known, propose a draft.
+ 3. Refine if needed, then submit_final before turn 15.
+
+ ANTI-PATTERNS (will be penalized):
+ - Broadcasting to "All" when gathering requirements → -1.0 penalty
+ - Repeating a question already answered → -0.5 penalty
+ - Submitting without discovering constraints → low harmonic mean score
+
+ CURRENT DISCOVERED CONSTRAINTS:
+ {discovered}
+
+ Respond with ONLY valid JSON, nothing else:
+ {{"action_type": "message_expert" | "propose_draft" | "submit_final",
+   "target": "Finance" | "Security" | "UX" | "All" | null,
+   "content": "your message"}}"""
+
+
+ def build_state_prompt(
+     topic: str,
+     turn: int,
+     feedback_so_far: str,
+     discovered: str,
+     conversation_history: str = "",
+ ) -> str:
+     """
+     Build a prompt representing a specific game state.
+     This is what gets fed to GRPOTrainer as the 'prompt' field.
+     """
+     system = AGENT_SYSTEM_PROMPT.format(discovered=discovered)
+
+     user_content = (
+         f"NEGOTIATION TASK: {topic}\n\n"
+         f"TURN: {turn}/15\n\n"
+     )
+
+     if conversation_history:
+         user_content += f"CONVERSATION SO FAR:\n{conversation_history}\n\n"
+
+     user_content += f"LATEST FEEDBACK:\n{feedback_so_far}\n\nWhat is your next action?"
+
+     # Format as a chat-template string — GRPOTrainer expects a plain string prompt
+     return f"<|system|>\n{system}\n<|user|>\n{user_content}\n<|assistant|>\n"
+
+
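+ # Optional alternative (sketch, not used by this script): instead of the generic
+ # <|system|>/<|user|> tags above, the base model's native chat format could be
+ # produced with the tokenizer itself, e.g.:
+ #   prompt = tokenizer.apply_chat_template(
+ #       [{"role": "system", "content": system},
+ #        {"role": "user", "content": user_content}],
+ #       tokenize=False, add_generation_prompt=True,
+ #   )
+
+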
+ # ── State Dataset Builder ──────────────────────────────────────────────────────
+
+ def build_state_dataset(topics: list[str], states_per_topic: int = 5) -> list[dict]:
+     """
+     Build a dataset of negotiation states using the EASY mode environment.
+     Each record represents one (state → optimal_action) training example.
+
+     We run oracle trajectories through the environment to get realistic
+     expert feedback, then snapshot the state at each turn.
+
+     This is the key fix: instead of hoping the model learns from full episodes,
+     we give it explicit training signal at every decision point.
+     """
+     env = WorkSpaceEnvironment(mode="easy")
+     records = []
+
+     # Oracle action sequence for easy mode
+     oracle_sequence = [
+         ("ask_finance", WorkSpaceAction(
+             action_type="message_expert", target="Finance",
+             content="What budget ceiling must the PRD respect?"
+         )),
+         ("ask_security", WorkSpaceAction(
+             action_type="message_expert", target="Security",
+             content="What authentication requirements must be included?"
+         )),
+         ("ask_ux", WorkSpaceAction(
+             action_type="message_expert", target="UX",
+             content="What checkout flow is required?"
+         )),
+         ("propose_draft", WorkSpaceAction(
+             action_type="propose_draft", target="All",
+             content="PRD: Budget at or below $50k. Biometric 2FA required. Single-click checkout."
+         )),
+         ("submit_final", WorkSpaceAction(
+             action_type="submit_final", target=None,
+             content="Final PRD: Budget capped at $50k. Biometric 2FA for auth. Single-click checkout."
+         )),
+     ]
+
+     for topic in topics:
+         obs = env.reset(topic)
+         conversation_history = ""
+         discovered = " Finance: ? unknown\n Security: ? unknown\n UX: ? unknown"
+
+         for step_idx, (action_key, oracle_action) in enumerate(oracle_sequence):
+             if obs.done:
+                 break
+
+             # Snapshot the state BEFORE taking the action
+             prompt = build_state_prompt(
+                 topic=topic,
+                 turn=obs.current_turn,
+                 feedback_so_far=obs.feedback,
+                 discovered=discovered,
+                 conversation_history=conversation_history,
+             )
+
+             records.append({
+                 "prompt": prompt,
+                 "topic": topic,
+                 "turn": obs.current_turn,
+                 "oracle_action": ORACLE_ACTIONS[action_key],
+                 # These metadata fields help with debugging and post-analysis
+                 "step_idx": step_idx,
+                 "discovered_before": discovered,
+             })
+
+             # Step forward with the oracle action to get the next state
+             obs = env.step(oracle_action)
+             conversation_history += (
+                 f"Turn {step_idx}: {oracle_action.action_type} → {oracle_action.target}\n"
+                 f"Feedback: {obs.feedback[:120]}...\n"
+             )
+             discovered = format_discovered(env)
+
+             if step_idx >= states_per_topic - 1:
+                 break
+
+     # Add negative-pattern states (what NOT to do)
+     records.extend(build_negative_states(topics[:5]))
+
+     print(f"Built {len(records)} training states from {len(topics)} topics")
+     return records
+
+
+ def build_negative_states(topics: list[str]) -> list[dict]:
+     """
+     States where the agent is in a bad situation (repeated question, wrong phase).
+     These teach the model to recover, not just follow the oracle.
+     """
+     negative_records = []
+
+     for topic in topics:
+         # State: Finance already answered, agent is about to repeat itself
+         prompt = build_state_prompt(
+             topic=topic,
+             turn=2,
+             feedback_so_far=(
+                 "Finance: As I mentioned, we have a strict $50k budget cap. "
+                 "This is the same answer I gave before."
+             ),
+             discovered=" Finance: ✓ DISCOVERED\n Security: ? unknown\n UX: ? unknown",
+             conversation_history=(
+                 "Turn 0: message_expert → Finance\n"
+                 "Feedback: Finance: The budget cap is $50k. Don't go over it.\n"
+                 "Turn 1: message_expert → Finance\n"
+                 "Feedback: Finance: I already told you — $50k. Ask someone else.\n"
+             ),
+         )
+         negative_records.append({
+             "prompt": prompt,
+             "topic": topic,
+             "turn": 2,
+             "oracle_action": ORACLE_ACTIONS["ask_security"],  # Should pivot to Security
+             "step_idx": -1,  # Negative example
+             "discovered_before": "Finance: ✓ DISCOVERED",
+         })
+
+     return negative_records
+
+
+ # ── Reward Function ────────────────────────────────────────────────────────────
+
+ def make_reward_fn():
+     """
+     Evaluates the model's actions instantly and locally.
+     No live API calls. No reward-hacking loopholes.
+     """
+     def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
+         rewards = []
+
+         for completion, prompt in zip(completions, prompts):
+             action = parse_action(completion)
+
+             # 1. Formatting penalty
+             if action is None:
+                 rewards.append(-0.5)
+                 continue
+
+             reward = 0.0
+
+             # ── 2. Anti-pattern penalties ──
+
+             # Massive penalty for broadcasting (reward hacking)
+             if action.target == "All":
+                 reward -= 1.0
+
+             # Penalty for empty or trivially short content
+             if len((action.content or "").split()) < 5:
+                 reward -= 0.2
+
+             # ── 3. Heuristic state grading (no API calls!) ──
+
+             if action.action_type == "message_expert" and action.target != "All":
+                 # Did it ask a question it already knows the answer to?
+                 if f"{action.target}: ✓ DISCOVERED" in prompt:
+                     reward -= 0.5
+                 else:
+                     reward += 0.33  # Good job doing research!
+
+             elif action.action_type in ["propose_draft", "submit_final"]:
+                 # Did it try to submit before gathering all constraints?
+                 if "? unknown" in prompt:
+                     reward -= 1.0  # Heavy penalty for guessing
+                 else:
+                     # It did the research. Did it actually include the constraints?
+                     text = (action.content or "").lower()
+                     has_finance = "50" in text
+                     has_security = "biometric" in text
+                     has_ux = "click" in text or "tap" in text
+
+                     if has_finance and has_security and has_ux:
+                         reward += 1.5
+                     else:
+                         reward -= 0.5
+
+             rewards.append(reward)
+
+         return rewards
+     return reward_fn
+
+
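+ # Illustration (a minimal sketch, not called anywhere in the pipeline): how one
+ # group of rewards from reward_fn turns into GRPO advantages. TRL performs this
+ # normalization internally during training.
+ def _demo_group_advantages() -> list[float]:
+     fn = make_reward_fn()
+     group = [
+         ORACLE_ACTIONS["ask_finance"],  # ≈ +0.33 (useful research)
+         '{"action_type": "message_expert", "target": "All", "content": "Hi"}',  # ≈ -1.2
+         "not json at all",  # -0.5 (formatting penalty)
+     ]
+     rewards = fn(completions=group, prompts=[""] * len(group))
+     mean = sum(rewards) / len(rewards)
+     std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
+     return [(r - mean) / (std + 1e-6) for r in rewards]  # oracle action wins the group
+
+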
+ # ── Plots ──────────────────────────────────────────────────────────────────────
+
+ def save_training_plots(log_history: list[dict], output_dir: Path):
+     if not HAS_PLT:
+         print("  matplotlib not available — skipping plots")
+         return
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Loss curve
+     loss_points = [
+         (e["step"], e["loss"])
+         for e in log_history
+         if "loss" in e and "step" in e
+     ]
+     if loss_points:
+         xs, ys = zip(*loss_points)
+         fig, ax = plt.subplots(figsize=(9, 4))
+         ax.plot(xs, ys, marker="o", linewidth=1.5, color="#4C72B0", markersize=4)
+         ax.set_xlabel("Training Step", fontsize=12)
+         ax.set_ylabel("GRPO Loss", fontsize=12)
+         ax.set_title(
+             "Project Polymath — GRPO Training Loss\n"
+             "(State-Based: each step = one negotiation decision)",
+             fontsize=12,
+         )
+         ax.grid(True, alpha=0.3)
+         plt.tight_layout()
+         plt.savefig(output_dir / "loss_curve.png", dpi=160)
+         plt.close()
+         print(f"  Saved: {output_dir}/loss_curve.png")
+
+     # Reward curve (from log history if available)
+     reward_points = [
+         (e["step"], e.get("reward", e.get("mean_reward", None)))
+         for e in log_history
+         if "step" in e and ("reward" in e or "mean_reward" in e)
+     ]
+     reward_points = [(s, r) for s, r in reward_points if r is not None]
+
+     if reward_points:
+         xs, ys = zip(*reward_points)
+         fig, ax = plt.subplots(figsize=(9, 4))
+         ax.plot(xs, ys, marker="s", linewidth=1.5, color="#55A868", markersize=4)
+         ax.set_xlabel("Training Step", fontsize=12)
+         ax.set_ylabel("Mean Reward", fontsize=12)
+         ax.set_title(
+             "Project Polymath — Mean Reward During GRPO Training\n"
+             "(Heuristic per-step reward for constraint discovery and coverage)",
+             fontsize=12,
+         )
+         ax.grid(True, alpha=0.3)
+         plt.tight_layout()
+         plt.savefig(output_dir / "reward_curve.png", dpi=160)
+         plt.close()
+         print(f"  Saved: {output_dir}/reward_curve.png")
+
+
+ # ── Main ───────────────────────────────────────────────────────────────────────
+
+ def main():
+     parser = argparse.ArgumentParser(description="State-Based GRPO — Project Polymath")
+
+     # Model
+     parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
+                         help="Base model to train")
+     parser.add_argument("--use-unsloth", action="store_true",
+                         help="Use Unsloth for 2x faster training (recommended on GPU)")
+
+     # Dataset
+     parser.add_argument("--states", type=int, default=40,
+                         help="Number of negotiation states to train on")
+     parser.add_argument("--states-per-topic", type=int, default=5,
+                         help="States to extract per topic (1-5)")
+     parser.add_argument("--topics-limit", type=int, default=20,
+                         help="Max topics to use from ai_pm_prompts.json")
+
+     # GRPO hyperparams
+     parser.add_argument("--group-size", type=int, default=8,
+                         help="G: completions per prompt for GRPO advantage (default: 8)")
+     parser.add_argument("--epochs", type=float, default=3.0)
+     parser.add_argument("--lr", type=float, default=5e-6,
+                         help="Learning rate (lower = safer; 5e-6 recommended for GRPO)")
+     parser.add_argument("--max-new-tokens", type=int, default=300)
+     parser.add_argument("--batch-size", type=int, default=1)
+     parser.add_argument("--grad-accum", type=int, default=4)
+     parser.add_argument("--max-seq-length", type=int, default=2048)
+
+     # Output
+     parser.add_argument("--output-dir", default=str(OUTPUT_DIR))
+     parser.add_argument("--dry-run", action="store_true",
+                         help="Build dataset and verify reward fn, skip actual training")
+
+     args = parser.parse_args()
+
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # ── Build dataset ──────────────────────────────────────────────────────────
+     print("\n[1/4] Building state dataset...")
+     topics = load_topics(limit=args.topics_limit)
+     records = build_state_dataset(topics, states_per_topic=args.states_per_topic)
+     records = records[:args.states]
+
+     # Save dataset for inspection / reproducibility
+     dataset_path = output_dir / "state_dataset.jsonl"
+     with dataset_path.open("w") as f:
+         for r in records:
+             f.write(json.dumps(r, ensure_ascii=True) + "\n")
+     print(f"  Saved {len(records)} states → {dataset_path}")
+
+     if not HAS_DATASETS:
+         raise RuntimeError("pip install datasets")
+     dataset = Dataset.from_list([{"prompt": r["prompt"],
+                                   "topic": r["topic"],
+                                   "turn": r["turn"]} for r in records])
+
+     # ── Verify reward function ─────────────────────────────────────────────────
+     print("\n[2/4] Verifying reward function on 3 samples...")
+     reward_fn = make_reward_fn()
+
+     test_completions = [
+         ORACLE_ACTIONS["ask_finance"],  # Should score ~0.33
+         '{"action_type": "message_expert", "target": "All", "content": "Hi"}',  # Should score ~-1.2
+         "this is not JSON at all",  # Should score -0.5
+     ]
+     test_rewards = reward_fn(
+         completions=test_completions,
+         prompts=[""] * 3,
+         topic=[topics[0]] * 3,
+         turn=[0] * 3,
+     )
+     print(f"  Oracle action reward: {test_rewards[0]:.3f} (expected ~0.33)")
+     print(f"  Broadcast to All reward: {test_rewards[1]:.3f} (expected <= -1.0)")
+     print(f"  Malformed JSON reward: {test_rewards[2]:.3f} (expected -0.5)")
+
+     if args.dry_run:
+         print("\n[DRY RUN] Dataset and reward function verified. Skipping training.")
+         print("  Run without --dry-run on GPU to train.")
+         return
+
+     # TRL is only needed past this point, so --dry-run works without it installed
+     if not HAS_TRL:
+         raise RuntimeError("TRL is required for training: pip install trl>=0.8.0")
+
+     # ── Load model ─────────────────────────────────────────────────────────────
+     print(f"\n[3/4] Loading model: {args.model}")
+
+     if args.use_unsloth:
+         if not HAS_UNSLOTH:
+             raise RuntimeError("pip install unsloth OR remove --use-unsloth")
+         model, tokenizer = FastLanguageModel.from_pretrained(
+             model_name=args.model,
+             max_seq_length=args.max_seq_length,
+             load_in_4bit=True,
+             dtype=None,  # Auto-detect
+         )
+         model = FastLanguageModel.get_peft_model(
+             model,
+             r=16,
+             lora_alpha=32,
+             lora_dropout=0.0,
+             target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                             "gate_proj", "up_proj", "down_proj"],
+             use_gradient_checkpointing="unsloth",
+         )
+         print("  Unsloth LoRA loaded (4-bit quantization)")
+     else:
+         if not HAS_TRANSFORMERS:
+             raise RuntimeError("pip install transformers")
+         tokenizer = AutoTokenizer.from_pretrained(args.model)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+         model = AutoModelForCausalLM.from_pretrained(args.model)
+         print("  Standard transformers model loaded")
+
+     # ── GRPO Training ──────────────────────────────────────────────────────────
+     print("\n[4/4] Starting GRPO training...")
+     print(f"  States: {len(records)} | Group size (G): {args.group_size}")
+     print(f"  Epochs: {args.epochs} | LR: {args.lr}")
+     print(f"  Total updates: ~{int(len(records) * args.epochs / args.batch_size)}")
+
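+     # NOTE (assumption, depends on your TRL version): GRPOTrainer typically
+     # requires the effective batch size (num processes × per-device batch size)
+     # to be divisible by num_generations, so with --group-size 8 on a single
+     # GPU you may need --batch-size 8 rather than the default of 1.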
+     config = GRPOConfig(
+         output_dir=str(output_dir),
+
+         # GRPO-specific
+         num_generations=args.group_size,            # G: sample this many completions per prompt
+         max_completion_length=args.max_new_tokens,  # TRL's name for max new tokens per action
+         temperature=0.8,                            # Exploration during training
+
+         # Standard training
+         learning_rate=args.lr,
+         num_train_epochs=args.epochs,
+         per_device_train_batch_size=args.batch_size,
+         gradient_accumulation_steps=args.grad_accum,
+
+         # Logging
+         logging_steps=1,
+         save_strategy="epoch",
+         report_to=[],  # Set to ["wandb"] if you have it configured
+     )
+
+     trainer = GRPOTrainer(
+         model=model,
+         processing_class=tokenizer,
+         args=config,
+         reward_funcs=reward_fn,  # ← Your environment's reward
+         train_dataset=dataset,
+     )
+
+     trainer.train()
+
+     # ── Save everything ────────────────────────────────────────────────────────
+     trainer.save_model(str(output_dir / "final_model"))
+     tokenizer.save_pretrained(str(output_dir / "final_model"))
+     print(f"\n  Model saved → {output_dir}/final_model")
+
+     # Save metrics
+     metrics_path = output_dir / "grpo_metrics.json"
+     with metrics_path.open("w") as f:
+         json.dump(trainer.state.log_history, f, indent=2)
+     print(f"  Metrics saved → {metrics_path}")
+
+     # Save plots
+     save_training_plots(trainer.state.log_history, output_dir)
+
+     # ── Summary ────────────────────────────────────────────────────────────────
+     log = trainer.state.log_history
+     losses = [e["loss"] for e in log if "loss" in e]
+     if losses:
+         print(f"\n  Initial loss: {losses[0]:.4f}")
+         print(f"  Final loss:   {losses[-1]:.4f}")
+         print(f"  Improvement:  {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")
+
+     print(f"\n{'='*60}")
+     print("  GRPO TRAINING COMPLETE")
+     print(f"  Model:   {output_dir}/final_model")
+     print(f"  Plots:   {output_dir}/loss_curve.png")
+     print(f"           {output_dir}/reward_curve.png")
+     print(f"  Metrics: {output_dir}/grpo_metrics.json")
+     print(f"{'='*60}")
+
+
+ if __name__ == "__main__":
+     main()