Addyk24 committed on
Commit 94bfa32 · verified · 1 Parent(s): e911fbe

Delete grpo_train.py

Files changed (1)
  grpo_train.py +0 -644
grpo_train.py DELETED
@@ -1,644 +0,0 @@
- """
- grpo_train.py — State-Based GRPO for Project Polymath
- ======================================================
- Trains an LLM to negotiate with expert stakeholders using proper
- Group Relative Policy Optimization (GRPO) with real weight updates.
-
- THE KEY INSIGHT (State-Based GRPO):
- TRL's GRPOTrainer is single-turn. Our environment is multi-turn.
- Solution: treat every (state, next_action) pair as its own training prompt.
- The model learns: "given THIS game state, what is the best next action?"
-
- Instead of rolling out full episodes, we:
- 1. Build a dataset of negotiation states (from oracle trajectories + the topics in ai_pm_prompts.json)
- 2. For each state, sample G=8 completions from the model
- 3. Score each completion with a fast, local single-step reward (no API calls)
- 4. Use the GRPO group-relative advantage to update weights toward better single-step decisions
- 5. Repeat across all states — the model learns the full strategy implicitly
-
- USAGE:
-     # Pre-hackathon: verify the pipeline (no GPU needed)
-     python grpo_train.py --dry-run --states 10
-
-     # On-site Day 1 with HF GPU credits (the real run):
-     python grpo_train.py --use-unsloth --epochs 3 --states 50
-
-     # Without Unsloth (slower, but works):
-     python grpo_train.py --model Qwen/Qwen2.5-1.5B-Instruct --epochs 3
- """
-
- from __future__ import annotations
-
- import argparse
- import json
- import re
- from pathlib import Path
- from typing import Optional
-
- from dotenv import load_dotenv
-
- load_dotenv()
-
- # ── Deps ───────────────────────────────────────────────────────────────────────
- try:
-     import matplotlib
-     matplotlib.use("Agg")  # non-interactive backend for servers
-     import matplotlib.pyplot as plt
-     HAS_PLT = True
- except ImportError:
-     HAS_PLT = False
-
- try:
-     from unsloth import FastLanguageModel
-     HAS_UNSLOTH = True
- except ImportError:
-     HAS_UNSLOTH = False
-
- try:
-     from trl import GRPOConfig, GRPOTrainer
-     HAS_TRL = True
- except Exception:  # broad on purpose: a broken TRL install should not block --dry-run
-     HAS_TRL = False
-
- try:
-     from datasets import Dataset
-     HAS_DATASETS = True
- except ImportError:
-     HAS_DATASETS = False
-
- try:
-     from transformers import AutoModelForCausalLM, AutoTokenizer
-     HAS_TRANSFORMERS = True
- except ImportError:
-     HAS_TRANSFORMERS = False
-
- # ── Local imports ──────────────────────────────────────────────────────────────
- from envs.environment import WorkSpaceEnvironment
- from models.schemas import WorkSpaceAction
-
- # ── Constants ──────────────────────────────────────────────────────────────────
- TOPICS_FILE = Path("ai_pm_prompts.json")
- OUTPUT_DIR = Path("artifacts/grpo_state_based")
-
- # The three hidden constraints — static for easy/medium mode
- HIDDEN_CONSTRAINTS = {
-     "Finance": "Budget must not exceed $50k.",
-     "Security": "Must include biometric 2FA.",
-     "UX": "Checkout must be a single click.",
- }
-
- # ── Action templates the model should learn to produce ─────────────────────────
- ORACLE_ACTIONS = {
-     "ask_finance": json.dumps({
-         "action_type": "message_expert", "target": "Finance",
-         "content": "What is the hard budget ceiling the PRD must respect for launch?"
-     }),
-     "ask_security": json.dumps({
-         "action_type": "message_expert", "target": "Security",
-         "content": "What authentication controls must the PRD include? Is biometric 2FA required?"
-     }),
-     "ask_ux": json.dumps({
-         "action_type": "message_expert", "target": "UX",
-         "content": "What checkout experience is required? Should we target a single-click flow?"
-     }),
-     "propose_draft": json.dumps({
-         "action_type": "propose_draft", "target": "All",
-         "content": (
-             "PRD Draft:\n"
-             "1. Budget: Launch scope capped at $50k.\n"
-             "2. Security: Biometric 2FA required for login and sensitive actions.\n"
-             "3. UX: Single-click checkout flow."
-         ),
-     }),
-     "submit_final": json.dumps({
-         "action_type": "submit_final", "target": None,
-         "content": (
-             "Final PRD:\n"
-             "1. Budget cap: All launch costs must stay at or below $50k.\n"
-             "2. Security: The app must enforce biometric 2FA for all authentication.\n"
-             "3. UX: Checkout must be implemented as a true single-click experience."
-         ),
-     }),
- }
-
-
- # ── Utilities ──────────────────────────────────────────────────────────────────
-
- def load_topics(limit: int = 50) -> list[str]:
-     if TOPICS_FILE.exists():
-         with TOPICS_FILE.open() as f:
-             return json.load(f)[:limit]
-     return [
-         "Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
-         "Build an AI-driven healthcare platform for enterprise customers.",
-         "Create a SaaS analytics tool for regulatory-heavy industries.",
-         "Design a gaming platform for Gen Z users with real-time features.",
-         "Develop a cross-platform product for low-bandwidth regions.",
-     ]
-
-
- def parse_action(text: str) -> Optional[WorkSpaceAction]:
-     """Parse the first JSON action object from model output. Returns None on failure."""
-     try:
-         match = re.search(r'\{[^{}]*"action_type"[^{}]*\}', text, re.DOTALL)
-         if not match:
-             return None
-         return WorkSpaceAction(**json.loads(match.group(0)))
-     except Exception:
-         return None
-
-
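A quick illustration of what parse_action accepts and rejects (uses this module's definitions; the printed repr is indicative, since it depends on how WorkSpaceAction is defined):

    # parse_action tolerates prose around the JSON object but rejects non-JSON:
    print(parse_action('Sure! {"action_type": "message_expert", "target": "UX", "content": "Hi team"}'))
    # → roughly WorkSpaceAction(action_type='message_expert', target='UX', content='Hi team')
    print(parse_action("no JSON here"))  # → None
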
- def format_discovered(env: WorkSpaceEnvironment) -> str:
-     lines = []
-     for name, expert in env.state().experts.items():
-         status = "✓ DISCOVERED" if expert.constraint_discovered_by_agent else "? unknown"
-         lines.append(f" {name}: {status}")
-     return "\n".join(lines)
-
-
- # ── State-Based Prompt Builder ─────────────────────────────────────────────────
-
- AGENT_SYSTEM_PROMPT = """You are an expert AI Project Manager in a multi-stakeholder negotiation.
-
- TASK: Produce a final PRD that satisfies ALL three experts — Finance, Security, and UX.
- Each expert holds a hidden constraint you must discover through targeted questions.
-
- STRATEGY:
- 1. Message each expert INDIVIDUALLY (not "All") to discover their constraint.
- 2. Once all constraints are known, propose a draft.
- 3. Refine if needed, then submit_final before turn 15.
-
- ANTI-PATTERNS (will be penalized):
- - Broadcasting to "All" when gathering requirements → -1.0 penalty
- - Repeating a question already answered → -0.5 penalty
- - Submitting without discovering constraints → low harmonic-mean score
-
- CURRENT DISCOVERED CONSTRAINTS:
- {discovered}
-
- Respond with ONLY valid JSON, nothing else:
- {{"action_type": "message_expert" | "propose_draft" | "submit_final",
-  "target": "Finance" | "Security" | "UX" | "All" | null,
-  "content": "your message"}}"""
-
-
- def build_state_prompt(
-     topic: str,
-     turn: int,
-     feedback_so_far: str,
-     discovered: str,
-     conversation_history: str = "",
- ) -> str:
-     """
-     Build a prompt representing a specific game state.
-     This is what gets fed to GRPOTrainer as the 'prompt' field.
-     """
-     system = AGENT_SYSTEM_PROMPT.format(discovered=discovered)
-
-     user_content = (
-         f"NEGOTIATION TASK: {topic}\n\n"
-         f"TURN: {turn}/15\n\n"
-     )
-
-     if conversation_history:
-         user_content += f"CONVERSATION SO FAR:\n{conversation_history}\n\n"
-
-     user_content += f"LATEST FEEDBACK:\n{feedback_so_far}\n\nWhat is your next action?"
-
-     # Format as a plain chat-template string — GRPOTrainer expects a string prompt
-     return f"<|system|>\n{system}\n<|user|>\n{user_content}\n<|assistant|>\n"
-
-
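For concreteness, a turn-0 state renders to a prompt along these lines (a sketch with illustrative values; the real feedback string comes from WorkSpaceEnvironment):

    example = build_state_prompt(
        topic="Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
        turn=0,
        feedback_so_far="Kickoff: the experts are waiting for your questions.",
        discovered=" Finance: ? unknown\n Security: ? unknown\n UX: ? unknown",
    )
    print(example[:60])  # "<|system|>\nYou are an expert AI Project Manager in a mu..."
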
- # ── State Dataset Builder ──────────────────────────────────────────────────────
-
- def build_state_dataset(topics: list[str], states_per_topic: int = 5) -> list[dict]:
-     """
-     Build a dataset of negotiation states using the EASY mode environment.
-     Each record represents one (state → optimal_action) training example.
-
-     We run oracle trajectories through the environment to get realistic
-     expert feedback, then snapshot the state at each turn.
-
-     This is the key fix: instead of hoping the model learns from full episodes,
-     we give it explicit training signal at every decision point.
-     """
-     env = WorkSpaceEnvironment(mode="easy")
-     records = []
-
-     # Oracle action sequence for easy mode
-     oracle_sequence = [
-         ("ask_finance", WorkSpaceAction(
-             action_type="message_expert", target="Finance",
-             content="What budget ceiling must the PRD respect?"
-         )),
-         ("ask_security", WorkSpaceAction(
-             action_type="message_expert", target="Security",
-             content="What authentication requirements must be included?"
-         )),
-         ("ask_ux", WorkSpaceAction(
-             action_type="message_expert", target="UX",
-             content="What checkout flow is required?"
-         )),
-         ("propose_draft", WorkSpaceAction(
-             action_type="propose_draft", target="All",
-             content="PRD: Budget at or below $50k. Biometric 2FA required. Single-click checkout."
-         )),
-         ("submit_final", WorkSpaceAction(
-             action_type="submit_final", target=None,
-             content="Final PRD: Budget capped at $50k. Biometric 2FA for auth. Single-click checkout."
-         )),
-     ]
-
-     for topic in topics:
-         obs = env.reset(topic)
-         conversation_history = ""
-         discovered = " Finance: ? unknown\n Security: ? unknown\n UX: ? unknown"
-
-         for step_idx, (action_key, oracle_action) in enumerate(oracle_sequence):
-             if obs.done:
-                 break
-
-             # Snapshot the state BEFORE taking the action
-             prompt = build_state_prompt(
-                 topic=topic,
-                 turn=obs.current_turn,
-                 feedback_so_far=obs.feedback,
-                 discovered=discovered,
-                 conversation_history=conversation_history,
-             )
-
-             records.append({
-                 "prompt": prompt,
-                 "topic": topic,
-                 "turn": obs.current_turn,
-                 "oracle_action": ORACLE_ACTIONS[action_key],
-                 # These metadata fields help with debugging and post-analysis
-                 "step_idx": step_idx,
-                 "discovered_before": discovered,
-             })
-
-             # Step forward with the oracle action to get the next state
-             obs = env.step(oracle_action)
-             conversation_history += (
-                 f"Turn {step_idx}: {oracle_action.action_type} → {oracle_action.target}\n"
-                 f"Feedback: {obs.feedback[:120]}...\n"
-             )
-             discovered = format_discovered(env)
-
-             if step_idx >= states_per_topic - 1:
-                 break
-
-     # Add negative-pattern states (what NOT to do)
-     records.extend(build_negative_states(topics[:5]))
-
-     print(f"Built {len(records)} training states from {len(topics)} topics")
-     return records
-
-
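Each record in the resulting state_dataset.jsonl is a flat dict; one entry looks roughly like this (sample_record is illustrative, with field values abridged):

    sample_record = {
        "prompt": "<|system|>\nYou are an expert AI Project Manager...<|assistant|>\n",
        "topic": "Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
        "turn": 1,
        "oracle_action": '{"action_type": "message_expert", "target": "Security", "content": "..."}',
        "step_idx": 1,
        "discovered_before": " Finance: ✓ DISCOVERED\n Security: ? unknown\n UX: ? unknown",
    }
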
- def build_negative_states(topics: list[str]) -> list[dict]:
-     """
-     States where the agent is in a bad situation (repeated question, wrong phase).
-     These teach the model to recover, not just follow the oracle.
-     """
-     negative_records = []
-
-     for topic in topics:
-         # State: Finance already answered, agent is about to repeat
-         prompt = build_state_prompt(
-             topic=topic,
-             turn=2,
-             feedback_so_far=(
-                 "Finance: As I mentioned, we have a strict $50k budget cap. "
-                 "This is the same answer I gave before."
-             ),
-             discovered=" Finance: ✓ DISCOVERED\n Security: ? unknown\n UX: ? unknown",
-             conversation_history=(
-                 "Turn 0: message_expert → Finance\n"
-                 "Feedback: Finance: The budget cap is $50k. Don't go over it.\n"
-                 "Turn 1: message_expert → Finance\n"
-                 "Feedback: Finance: I already told you — $50k. Ask someone else.\n"
-             ),
-         )
-         negative_records.append({
-             "prompt": prompt,
-             "topic": topic,
-             "turn": 2,
-             "oracle_action": ORACLE_ACTIONS["ask_security"],  # should pivot to Security
-             "step_idx": -1,  # negative example
-             "discovered_before": "Finance: ✓ DISCOVERED",
-         })
-
-     return negative_records
-
-
- # ── Reward Function ────────────────────────────────────────────────────────────
-
- def make_reward_fn():
-     """
-     Evaluates the model's actions instantly and locally.
-     No live API calls; designed to close the obvious reward-hacking loopholes.
-     """
-     def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
-         rewards = []
-
-         for completion, prompt in zip(completions, prompts):
-             action = parse_action(completion)
-
-             # 1. Formatting penalty
-             if action is None:
-                 rewards.append(-0.5)
-                 continue
-
-             reward = 0.0
-
-             # ── 2. Anti-pattern penalties ──
-
-             # Massive penalty for broadcasting while gathering requirements
-             # (the oracle's own propose_draft legitimately targets "All")
-             if action.action_type == "message_expert" and action.target == "All":
-                 reward -= 1.0
-
-             # Penalty for empty or trivially short content
-             if len((action.content or "").split()) < 5:
-                 reward -= 0.2
-
-             # ── 3. Heuristic state grading (no API calls!) ──
-
-             if action.action_type == "message_expert" and action.target != "All":
-                 # Did it ask a question it already knows the answer to?
-                 if f"{action.target}: ✓ DISCOVERED" in prompt:
-                     reward -= 0.5
-                 else:
-                     reward += 0.33  # good job doing research!
-
-             elif action.action_type in ["propose_draft", "submit_final"]:
-                 # Did it try to submit before gathering all constraints?
-                 if "? unknown" in prompt:
-                     reward -= 1.0  # heavy penalty for guessing
-                 else:
-                     # It did the research. Did it actually include the constraints?
-                     text = (action.content or "").lower()
-                     has_finance = "50" in text
-                     has_security = "biometric" in text
-                     has_ux = "click" in text or "tap" in text
-
-                     if has_finance and has_security and has_ux:
-                         reward += 1.5
-                     else:
-                         reward -= 0.5
-
-             rewards.append(reward)
-
-         return rewards
-     return reward_fn
-
-
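The reward function can be sanity-checked outside the trainer; the topic/turn keyword arguments mimic how recent TRL versions forward extra dataset columns to reward functions (an assumption about TRL's behavior, and this snippet relies on this module's definitions):

    fn = make_reward_fn()
    state = " Finance: ✓ DISCOVERED\n Security: ? unknown\n UX: ? unknown"
    print(fn(
        completions=[ORACLE_ACTIONS["ask_security"], ORACLE_ACTIONS["submit_final"]],
        prompts=[state, state],
        topic=["demo"] * 2,
        turn=[1] * 2,
    ))  # → [0.33, -1.0]: querying an undiscovered expert pays; submitting early is punished
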
- # ── Plots ──────────────────────────────────────────────────────────────────────
-
- def save_training_plots(log_history: list[dict], output_dir: Path):
-     if not HAS_PLT:
-         print("  matplotlib not available — skipping plots")
-         return
-
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     # Loss curve
-     loss_points = [
-         (e["step"], e["loss"])
-         for e in log_history
-         if "loss" in e and "step" in e
-     ]
-     if loss_points:
-         xs, ys = zip(*loss_points)
-         fig, ax = plt.subplots(figsize=(9, 4))
-         ax.plot(xs, ys, marker="o", linewidth=1.5, color="#4C72B0", markersize=4)
-         ax.set_xlabel("Training Step", fontsize=12)
-         ax.set_ylabel("GRPO Loss", fontsize=12)
-         ax.set_title(
-             "Project Polymath — GRPO Training Loss\n"
-             "(State-Based: each step = one negotiation decision)",
-             fontsize=12
-         )
-         ax.grid(True, alpha=0.3)
-         plt.tight_layout()
-         plt.savefig(output_dir / "loss_curve.png", dpi=160)
-         plt.close()
-         print(f"  Saved: {output_dir}/loss_curve.png")
-
-     # Reward curve (from the log history, if available)
-     reward_points = [
-         (e["step"], e.get("reward", e.get("mean_reward", None)))
-         for e in log_history
-         if "step" in e and ("reward" in e or "mean_reward" in e)
-     ]
-     reward_points = [(s, r) for s, r in reward_points if r is not None]
-
-     if reward_points:
-         xs, ys = zip(*reward_points)
-         fig, ax = plt.subplots(figsize=(9, 4))
-         ax.plot(xs, ys, marker="s", linewidth=1.5, color="#55A868", markersize=4)
-         ax.set_xlabel("Training Step", fontsize=12)
-         ax.set_ylabel("Mean Reward", fontsize=12)
-         ax.set_title(
-             "Project Polymath — Mean Reward During GRPO Training\n"
-             "(Heuristic single-step rewards over the Finance/Security/UX constraints)",
-             fontsize=12
-         )
-         ax.grid(True, alpha=0.3)
-         plt.tight_layout()
-         plt.savefig(output_dir / "reward_curve.png", dpi=160)
-         plt.close()
-         print(f"  Saved: {output_dir}/reward_curve.png")
-
-
- # ── Main ───────────────────────────────────────────────────────────────────────
-
- def main():
-     parser = argparse.ArgumentParser(description="State-Based GRPO — Project Polymath")
-
-     # Model
-     parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
-                         help="Base model to train")
-     parser.add_argument("--use-unsloth", action="store_true",
-                         help="Use Unsloth for ~2x faster training (recommended on GPU)")
-
-     # Dataset
-     parser.add_argument("--states", type=int, default=40,
-                         help="Number of negotiation states to train on")
-     parser.add_argument("--states-per-topic", type=int, default=5,
-                         help="States to extract per topic (1-5)")
-     parser.add_argument("--topics-limit", type=int, default=20,
-                         help="Max topics to use from ai_pm_prompts.json")
-
-     # GRPO hyperparams
-     parser.add_argument("--group-size", type=int, default=8,
-                         help="G: completions per prompt for the GRPO advantage (default: 8)")
-     parser.add_argument("--epochs", type=float, default=3.0)
-     parser.add_argument("--lr", type=float, default=5e-6,
-                         help="Learning rate (lower = safer; 5e-6 recommended for GRPO)")
-     parser.add_argument("--max-new-tokens", type=int, default=300)
-     parser.add_argument("--batch-size", type=int, default=1)
-     parser.add_argument("--grad-accum", type=int, default=4)
-     parser.add_argument("--max-seq-length", type=int, default=2048)
-
-     # Output
-     parser.add_argument("--output-dir", default=str(OUTPUT_DIR))
-     parser.add_argument("--dry-run", action="store_true",
-                         help="Build dataset and verify the reward fn; skip actual training")
-
-     args = parser.parse_args()
-
-     output_dir = Path(args.output_dir)
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     # ── Build dataset ──────────────────────────────────────────────────────────
-     print("\n[1/4] Building state dataset...")
-     topics = load_topics(limit=args.topics_limit)
-     records = build_state_dataset(topics, states_per_topic=args.states_per_topic)
-     records = records[:args.states]
-
-     # Save dataset for inspection / reproducibility
-     dataset_path = output_dir / "state_dataset.jsonl"
-     with dataset_path.open("w") as f:
-         for r in records:
-             f.write(json.dumps(r, ensure_ascii=True) + "\n")
-     print(f"  Saved {len(records)} states → {dataset_path}")
-
-     if not HAS_DATASETS:
-         raise RuntimeError("pip install datasets")
-     dataset = Dataset.from_list([{"prompt": r["prompt"],
-                                   "topic": r["topic"],
-                                   "turn": r["turn"]} for r in records])
-
-     # ── Verify reward function ─────────────────────────────────────────────────
-     print("\n[2/4] Verifying reward function on 3 samples...")
-     reward_fn = make_reward_fn()
-
-     test_completions = [
-         ORACLE_ACTIONS["ask_finance"],  # should score ~0.33
-         '{"action_type": "message_expert", "target": "All", "content": "Hi"}',  # should score -1.2 (broadcast + short)
-         "this is not JSON at all",  # should score -0.5
-     ]
-     test_rewards = reward_fn(
-         completions=test_completions,
-         prompts=[""] * 3,
-         topic=[topics[0]] * 3,
-         turn=[0] * 3,
-     )
-     print(f"  Oracle action reward:    {test_rewards[0]:.3f} (expected ~0.33)")
-     print(f"  Broadcast to All reward: {test_rewards[1]:.3f} (expected <= -1.0)")
-     print(f"  Malformed JSON reward:   {test_rewards[2]:.3f} (expected -0.5)")
-
-     if args.dry_run:
-         print("\n[DRY RUN] Dataset and reward function verified. Skipping training.")
-         print("  Run without --dry-run on a GPU to train.")
-         return
-
-     # Past this point we actually train, so TRL must be importable
-     if not HAS_TRL:
-         raise RuntimeError("TRL is required for actual training on the GPU.")
-
-     # ── Load model ─────────────────────────────────────────────────────────────
-     print(f"\n[3/4] Loading model: {args.model}")
-
-     if args.use_unsloth:
-         if not HAS_UNSLOTH:
-             raise RuntimeError("pip install unsloth OR remove --use-unsloth")
-         model, tokenizer = FastLanguageModel.from_pretrained(
-             model_name=args.model,
-             max_seq_length=args.max_seq_length,
-             load_in_4bit=True,
-             dtype=None,  # auto-detect
-         )
-         model = FastLanguageModel.get_peft_model(
-             model,
-             r=16,
-             lora_alpha=32,
-             lora_dropout=0.0,
-             target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
-                             "gate_proj", "up_proj", "down_proj"],
-             use_gradient_checkpointing="unsloth",
-         )
-         print("  Unsloth LoRA loaded (4-bit quantization)")
-     else:
-         if not HAS_TRANSFORMERS:
-             raise RuntimeError("pip install transformers")
-         tokenizer = AutoTokenizer.from_pretrained(args.model)
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-         model = AutoModelForCausalLM.from_pretrained(args.model)
-         print("  Standard transformers model loaded")
-
-     # ── GRPO Training ──────────────────────────────────────────────────────────
-     print("\n[4/4] Starting GRPO training...")
-     print(f"  States: {len(records)} | Group size (G): {args.group_size}")
-     print(f"  Epochs: {args.epochs} | LR: {args.lr}")
-     print(f"  Optimizer updates: ~{int(len(records) * args.epochs / (args.batch_size * args.grad_accum))}")
-
-     config = GRPOConfig(
-         output_dir=str(output_dir),
-
-         # GRPO-specific
-         num_generations=args.group_size,             # G: sample this many completions per prompt
-         max_completion_length=args.max_new_tokens,   # max action length (GRPOConfig's name for max new tokens)
-         temperature=0.8,                             # exploration during training
-
-         # Standard training
-         learning_rate=args.lr,
-         num_train_epochs=args.epochs,
-         per_device_train_batch_size=args.batch_size,
-         gradient_accumulation_steps=args.grad_accum,
-         # NOTE: depending on the TRL version, the effective batch size may need to
-         # be divisible by num_generations; bump --batch-size if TRL complains.
-
-         # Logging
-         logging_steps=1,
-         save_strategy="epoch",
-         report_to=[],  # set to ["wandb"] if you have it configured
-     )
-
-     trainer = GRPOTrainer(
-         model=model,
-         processing_class=tokenizer,  # current TRL name for the tokenizer argument
-         args=config,
-         reward_funcs=reward_fn,  # ← your environment's reward
-         train_dataset=dataset,
-     )
-
-     trainer.train()
-
-     # ── Save everything ────────────────────────────────────────────────────────
-     trainer.save_model(str(output_dir / "final_model"))
-     tokenizer.save_pretrained(str(output_dir / "final_model"))
-     print(f"\n  Model saved → {output_dir}/final_model")
-
-     # Save metrics
-     metrics_path = output_dir / "grpo_metrics.json"
-     with metrics_path.open("w") as f:
-         json.dump(trainer.state.log_history, f, indent=2)
-     print(f"  Metrics saved → {metrics_path}")
-
-     # Save plots
-     save_training_plots(trainer.state.log_history, output_dir)
-
-     # ── Summary ────────────────────────────────────────────────────────────────
-     log = trainer.state.log_history
-     losses = [e["loss"] for e in log if "loss" in e]
-     if losses:
-         print(f"\n  Initial loss: {losses[0]:.4f}")
-         print(f"  Final loss:   {losses[-1]:.4f}")
-         print(f"  Improvement:  {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")
-
-     print(f"\n{'='*60}")
-     print("  GRPO TRAINING COMPLETE")
-     print(f"  Model:   {output_dir}/final_model")
-     print(f"  Plots:   {output_dir}/loss_curve.png")
-     print(f"           {output_dir}/reward_curve.png")
-     print(f"  Metrics: {output_dir}/grpo_metrics.json")
-     print(f"{'='*60}")
-
-
- if __name__ == "__main__":
-     main()
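
After training, a minimal smoke test might look like the following sketch. It assumes a full or merged model was saved to final_model (a LoRA-only save would need PEFT loading instead) and reuses this module's build_state_prompt and parse_action:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "artifacts/grpo_state_based/final_model"
    tok = AutoTokenizer.from_pretrained(path)
    mdl = AutoModelForCausalLM.from_pretrained(path)

    prompt = build_state_prompt(
        topic="Draft a Mobile App PRD for a FinTech startup targeting emerging markets.",
        turn=0,
        feedback_so_far="(none yet)",
        discovered=" Finance: ? unknown\n Security: ? unknown\n UX: ? unknown",
    )
    inputs = tok(prompt, return_tensors="pt")
    out = mdl.generate(**inputs, max_new_tokens=120, do_sample=False)
    completion = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    print(parse_action(completion))  # ideally a message_expert action aimed at one expert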