Spaces:
Sleeping
Sleeping
feat(kaggle): default to fixed-budget curriculum 20/30/50 episodes
Browse files- kaggle/build_notebook.py +56 -33
- kaggle/train_ermap_grpo_kaggle.ipynb +60 -37
kaggle/build_notebook.py
CHANGED
|
@@ -465,16 +465,24 @@ KL_BETA = 0.0
|
|
| 465 |
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
|
| 466 |
PUSH_EVERY_EPS = 20
|
| 467 |
USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
#
|
| 471 |
-
#
|
| 472 |
-
#
|
| 473 |
-
#
|
| 474 |
-
#
|
| 475 |
-
#
|
| 476 |
-
#
|
| 477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
|
| 479 |
PHASE_MIN_WIN_RATE = 0.20
|
| 480 |
CONVERGENCE_WINDOW = 3
|
|
@@ -493,13 +501,14 @@ os.environ["ERMAP_EMPATHY_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
|
| 493 |
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 494 |
|
| 495 |
print("Hyperparameters set:")
|
| 496 |
-
print(f" NUM_EPISODES
|
| 497 |
-
print(f" GROUP_SIZE
|
| 498 |
-
print(f"
|
| 499 |
-
print(f"
|
| 500 |
-
print(f"
|
| 501 |
-
print(f"
|
| 502 |
-
print(f"
|
|
|
|
| 503 |
"""
|
| 504 |
|
| 505 |
CELL_10_PREFLIGHT = """\
|
|
@@ -604,29 +613,42 @@ print("Hub-push hook installed.")
|
|
| 604 |
"""
|
| 605 |
|
| 606 |
CELL_13_TRAIN_MD = """\
|
| 607 |
-
## 13 Β· Run real training (
|
| 608 |
|
| 609 |
-
**
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
-
-
|
| 612 |
-
- ~1β2 min amortized per GRPO update (G=2 trajectories Γ response-token log-probs)
|
| 613 |
-
- **Per-group β 8β12 min** (2 episodes + 1 update)
|
| 614 |
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
|
| 620 |
-
|
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
| 625 |
`HF_RESUME_REPO` in cell 8.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
"""
|
| 627 |
|
| 628 |
CELL_13_TRAIN = """\
|
| 629 |
-
# === CELL 13 β REAL TRAINING (4-6 h cell) ===
|
| 630 |
metrics = train(
|
| 631 |
num_episodes=NUM_EPISODES,
|
| 632 |
group_size=GROUP_SIZE,
|
|
@@ -642,6 +664,7 @@ metrics = train(
|
|
| 642 |
phase_min_win_rate=PHASE_MIN_WIN_RATE,
|
| 643 |
convergence_window=CONVERGENCE_WINDOW,
|
| 644 |
early_stop=EARLY_STOP_ENABLED,
|
|
|
|
| 645 |
)
|
| 646 |
print(f"\\nTraining returned {len(metrics)} metric records.")
|
| 647 |
"""
|
|
|
|
| 465 |
OUTPUT_DIR = "/kaggle/working/er_map_grpo_checkpoints"
|
| 466 |
PUSH_EVERY_EPS = 20
|
| 467 |
USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image
|
| 468 |
+
|
| 469 |
+
# --- Curriculum mode: FIXED-BUDGET (recommended for Kaggle T4) -------------
|
| 470 |
+
# A fixed per-phase episode budget gives you a clean, predictable reward-
|
| 471 |
+
# growth curve and bounds your wall-clock. With GROUP_SIZE=2 below and an
|
| 472 |
+
# observed ~3 min / episode (Groq-dominated), 100 episodes β 5 hours.
|
| 473 |
+
#
|
| 474 |
+
# Set PHASE_EPISODE_BUDGETS to None to fall back to early-stopping mode,
|
| 475 |
+
# which terminates each phase the moment its reward target is hit (faster
|
| 476 |
+
# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED
|
| 477 |
+
# is automatically forced to False inside train() β the reward targets below
|
| 478 |
+
# become observational only (logged on the plots, not used for promotion).
|
| 479 |
+
PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50} # 20 + 30 + 50 = 100 episodes
|
| 480 |
+
NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 100
|
| 481 |
+
|
| 482 |
+
# --- Per-phase reward thresholds (observational under fixed-budget) --------
|
| 483 |
+
# Plotted as horizontal target lines on the reward-growth chart so you can
|
| 484 |
+
# see at a glance whether each phase actually crossed its target.
|
| 485 |
+
EARLY_STOP_ENABLED = False # ignored when PHASE_EPISODE_BUDGETS is set
|
| 486 |
PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}
|
| 487 |
PHASE_MIN_WIN_RATE = 0.20
|
| 488 |
CONVERGENCE_WINDOW = 3
|
|
|
|
| 501 |
os.environ["ERMAP_MEDICAL_JUDGE_MODEL"] = "llama-3.3-70b-versatile"
|
| 502 |
|
| 503 |
print("Hyperparameters set:")
|
| 504 |
+
print(f" NUM_EPISODES = {NUM_EPISODES}")
|
| 505 |
+
print(f" GROUP_SIZE = {GROUP_SIZE}")
|
| 506 |
+
print(f" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}")
|
| 507 |
+
print(f" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)")
|
| 508 |
+
print(f" EARLY_STOP_ENABLED = {EARLY_STOP_ENABLED} (forced off under fixed budget)")
|
| 509 |
+
print(f" KL_BETA = {KL_BETA} (0.0 -> skip ref model, T4-safe)")
|
| 510 |
+
print(f" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)")
|
| 511 |
+
print(f" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)")
|
| 512 |
"""
|
| 513 |
|
| 514 |
CELL_10_PREFLIGHT = """\
|
|
|
|
| 613 |
"""
|
| 614 |
|
| 615 |
CELL_13_TRAIN_MD = """\
|
| 616 |
+
## 13 Β· Run real training (fixed-budget curriculum, ~5 hours)
|
| 617 |
|
| 618 |
+
**Mode:** fixed-budget (`PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50}` in cell 9).
|
| 619 |
+
The reward thresholds in `PHASE_REWARD_TARGETS` are **observational only** β
|
| 620 |
+
they're plotted as horizontal target lines on the reward curve, but they do
|
| 621 |
+
NOT cause early termination. Each phase runs its full episode budget so the
|
| 622 |
+
final reward-growth chart shows clean, monotone progression.
|
| 623 |
|
| 624 |
+
**Estimated wall-clock on Kaggle T4 (Γ1 active GPU):**
|
|
|
|
|
|
|
| 625 |
|
| 626 |
+
- ~2β4 min per episode (Doctor.generate + 4β8 Γ Groq API calls per turn)
|
| 627 |
+
- ~1β2 min amortized per GRPO update (G=2 trajectories)
|
| 628 |
+
- **Per-group β 5β10 min** (2 episodes + 1 update)
|
| 629 |
+
|
| 630 |
+
| Phase | Episodes (this run) | GRPO updates | Wall-clock estimate |
|
| 631 |
+
|---|---|---|---|
|
| 632 |
+
| 1 β Tool Mastery | **20** | 10 | ~50 min β 1.7 h |
|
| 633 |
+
| 2 β Clinical Reasoning | **30** | 15 | ~1.3 β 2.5 h |
|
| 634 |
+
| 3 β Empathetic Negotiation | **50** | 25 | ~2.0 β 4.0 h |
|
| 635 |
+
| **Total** | **100** | **50** | **~4.0 β 8.0 h** |
|
| 636 |
+
|
| 637 |
+
Checkpoints are pushed to HF Hub every `PUSH_EVERY_EPS=20` episodes, so if
|
| 638 |
+
the Kaggle session expires mid-run you can resume in a fresh session via
|
| 639 |
`HF_RESUME_REPO` in cell 8.
|
| 640 |
+
|
| 641 |
+
> **Want even faster?** Drop `PHASE_EPISODE_BUDGETS` to `{1: 10, 2: 15, 3: 25}`
|
| 642 |
+
> in cell 9 (50 episodes total, ~2.0 β 4.0 h). The curve will be choppier but
|
| 643 |
+
> still shows phase transitions cleanly.
|
| 644 |
+
>
|
| 645 |
+
> **Want adaptive (early-stop)?** Set `PHASE_EPISODE_BUDGETS = None` in cell 9
|
| 646 |
+
> and `EARLY_STOP_ENABLED = True`; each phase will end the moment its reward
|
| 647 |
+
> target is sustained for `CONVERGENCE_WINDOW=3` consecutive groups.
|
| 648 |
"""
|
| 649 |
|
| 650 |
CELL_13_TRAIN = """\
|
| 651 |
+
# === CELL 13 β REAL TRAINING (4-6 h cell, fixed-budget curriculum) ===
|
| 652 |
metrics = train(
|
| 653 |
num_episodes=NUM_EPISODES,
|
| 654 |
group_size=GROUP_SIZE,
|
|
|
|
| 664 |
phase_min_win_rate=PHASE_MIN_WIN_RATE,
|
| 665 |
convergence_window=CONVERGENCE_WINDOW,
|
| 666 |
early_stop=EARLY_STOP_ENABLED,
|
| 667 |
+
phase_episode_budgets=PHASE_EPISODE_BUDGETS, # None -> early-stop mode
|
| 668 |
)
|
| 669 |
print(f"\\nTraining returned {len(metrics)} metric records.")
|
| 670 |
"""
|
kaggle/train_ermap_grpo_kaggle.ipynb
CHANGED
|
@@ -441,16 +441,24 @@
|
|
| 441 |
"OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
|
| 442 |
"PUSH_EVERY_EPS = 20\n",
|
| 443 |
"USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image\n",
|
| 444 |
-
"
|
| 445 |
-
"\n",
|
| 446 |
-
"#
|
| 447 |
-
"#
|
| 448 |
-
"#
|
| 449 |
-
"#
|
| 450 |
-
"#
|
| 451 |
-
"#
|
| 452 |
-
"#
|
| 453 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
"PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}\n",
|
| 455 |
"PHASE_MIN_WIN_RATE = 0.20\n",
|
| 456 |
"CONVERGENCE_WINDOW = 3\n",
|
|
@@ -469,13 +477,14 @@
|
|
| 469 |
"os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 470 |
"\n",
|
| 471 |
"print(\"Hyperparameters set:\")\n",
|
| 472 |
-
"print(f\" NUM_EPISODES
|
| 473 |
-
"print(f\" GROUP_SIZE
|
| 474 |
-
"print(f\"
|
| 475 |
-
"print(f\"
|
| 476 |
-
"print(f\"
|
| 477 |
-
"print(f\"
|
| 478 |
-
"print(f\"
|
|
|
|
| 479 |
]
|
| 480 |
},
|
| 481 |
{
|
|
@@ -598,25 +607,38 @@
|
|
| 598 |
"cell_type": "markdown",
|
| 599 |
"metadata": {},
|
| 600 |
"source": [
|
| 601 |
-
"## 13 Β· Run real training (
|
| 602 |
-
"\n",
|
| 603 |
-
"**
|
| 604 |
-
"\n",
|
| 605 |
-
"
|
| 606 |
-
"
|
| 607 |
-
"
|
| 608 |
-
"\n",
|
| 609 |
-
"
|
| 610 |
-
"
|
| 611 |
-
"
|
| 612 |
-
"
|
| 613 |
-
"
|
| 614 |
-
"
|
| 615 |
-
"\n",
|
| 616 |
-
"
|
| 617 |
-
"
|
| 618 |
-
"
|
| 619 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
]
|
| 621 |
},
|
| 622 |
{
|
|
@@ -625,7 +647,7 @@
|
|
| 625 |
"metadata": {},
|
| 626 |
"outputs": [],
|
| 627 |
"source": [
|
| 628 |
-
"# === CELL 13 β REAL TRAINING (4-6 h cell) ===\n",
|
| 629 |
"metrics = train(\n",
|
| 630 |
" num_episodes=NUM_EPISODES,\n",
|
| 631 |
" group_size=GROUP_SIZE,\n",
|
|
@@ -641,6 +663,7 @@
|
|
| 641 |
" phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
|
| 642 |
" convergence_window=CONVERGENCE_WINDOW,\n",
|
| 643 |
" early_stop=EARLY_STOP_ENABLED,\n",
|
|
|
|
| 644 |
")\n",
|
| 645 |
"print(f\"\\nTraining returned {len(metrics)} metric records.\")"
|
| 646 |
]
|
|
|
|
| 441 |
"OUTPUT_DIR = \"/kaggle/working/er_map_grpo_checkpoints\"\n",
|
| 442 |
"PUSH_EVERY_EPS = 20\n",
|
| 443 |
"USE_WANDB = False # WANDB conflicts with protobuf 7 on Kaggle base image\n",
|
| 444 |
+
"\n",
|
| 445 |
+
"# --- Curriculum mode: FIXED-BUDGET (recommended for Kaggle T4) -------------\n",
|
| 446 |
+
"# A fixed per-phase episode budget gives you a clean, predictable reward-\n",
|
| 447 |
+
"# growth curve and bounds your wall-clock. With GROUP_SIZE=2 below and an\n",
|
| 448 |
+
"# observed ~3 min / episode (Groq-dominated), 100 episodes β 5 hours.\n",
|
| 449 |
+
"#\n",
|
| 450 |
+
"# Set PHASE_EPISODE_BUDGETS to None to fall back to early-stopping mode,\n",
|
| 451 |
+
"# which terminates each phase the moment its reward target is hit (faster\n",
|
| 452 |
+
"# but less predictable). When PHASE_EPISODE_BUDGETS is set, EARLY_STOP_ENABLED\n",
|
| 453 |
+
"# is automatically forced to False inside train() β the reward targets below\n",
|
| 454 |
+
"# become observational only (logged on the plots, not used for promotion).\n",
|
| 455 |
+
"PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50} # 20 + 30 + 50 = 100 episodes\n",
|
| 456 |
+
"NUM_EPISODES = sum(PHASE_EPISODE_BUDGETS.values()) # = 100\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"# --- Per-phase reward thresholds (observational under fixed-budget) --------\n",
|
| 459 |
+
"# Plotted as horizontal target lines on the reward-growth chart so you can\n",
|
| 460 |
+
"# see at a glance whether each phase actually crossed its target.\n",
|
| 461 |
+
"EARLY_STOP_ENABLED = False # ignored when PHASE_EPISODE_BUDGETS is set\n",
|
| 462 |
"PHASE_REWARD_TARGETS = {1: 1.2, 2: 1.1, 3: 1.0}\n",
|
| 463 |
"PHASE_MIN_WIN_RATE = 0.20\n",
|
| 464 |
"CONVERGENCE_WINDOW = 3\n",
|
|
|
|
| 477 |
"os.environ[\"ERMAP_MEDICAL_JUDGE_MODEL\"] = \"llama-3.3-70b-versatile\"\n",
|
| 478 |
"\n",
|
| 479 |
"print(\"Hyperparameters set:\")\n",
|
| 480 |
+
"print(f\" NUM_EPISODES = {NUM_EPISODES}\")\n",
|
| 481 |
+
"print(f\" GROUP_SIZE = {GROUP_SIZE}\")\n",
|
| 482 |
+
"print(f\" PHASE_EPISODE_BUDGETS = {PHASE_EPISODE_BUDGETS}\")\n",
|
| 483 |
+
"print(f\" PHASE_REWARD_TARGETS = {PHASE_REWARD_TARGETS} (observational)\")\n",
|
| 484 |
+
"print(f\" EARLY_STOP_ENABLED = {EARLY_STOP_ENABLED} (forced off under fixed budget)\")\n",
|
| 485 |
+
"print(f\" KL_BETA = {KL_BETA} (0.0 -> skip ref model, T4-safe)\")\n",
|
| 486 |
+
"print(f\" Nurse / Patient = llama-3.1-8b-instant (actors, high-volume)\")\n",
|
| 487 |
+
"print(f\" Empathy / Med Judge = llama-3.3-70b-versatile (graders, quality)\")"
|
| 488 |
]
|
| 489 |
},
|
| 490 |
{
|
|
|
|
| 607 |
"cell_type": "markdown",
|
| 608 |
"metadata": {},
|
| 609 |
"source": [
|
| 610 |
+
"## 13 Β· Run real training (fixed-budget curriculum, ~5 hours)\n",
|
| 611 |
+
"\n",
|
| 612 |
+
"**Mode:** fixed-budget (`PHASE_EPISODE_BUDGETS = {1: 20, 2: 30, 3: 50}` in cell 9).\n",
|
| 613 |
+
"The reward thresholds in `PHASE_REWARD_TARGETS` are **observational only** β\n",
|
| 614 |
+
"they're plotted as horizontal target lines on the reward curve, but they do\n",
|
| 615 |
+
"NOT cause early termination. Each phase runs its full episode budget so the\n",
|
| 616 |
+
"final reward-growth chart shows clean, monotone progression.\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"**Estimated wall-clock on Kaggle T4 (Γ1 active GPU):**\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"- ~2β4 min per episode (Doctor.generate + 4β8 Γ Groq API calls per turn)\n",
|
| 621 |
+
"- ~1β2 min amortized per GRPO update (G=2 trajectories)\n",
|
| 622 |
+
"- **Per-group β 5β10 min** (2 episodes + 1 update)\n",
|
| 623 |
+
"\n",
|
| 624 |
+
"| Phase | Episodes (this run) | GRPO updates | Wall-clock estimate |\n",
|
| 625 |
+
"|---|---|---|---|\n",
|
| 626 |
+
"| 1 β Tool Mastery | **20** | 10 | ~50 min β 1.7 h |\n",
|
| 627 |
+
"| 2 β Clinical Reasoning | **30** | 15 | ~1.3 β 2.5 h |\n",
|
| 628 |
+
"| 3 β Empathetic Negotiation | **50** | 25 | ~2.0 β 4.0 h |\n",
|
| 629 |
+
"| **Total** | **100** | **50** | **~4.0 β 8.0 h** |\n",
|
| 630 |
+
"\n",
|
| 631 |
+
"Checkpoints are pushed to HF Hub every `PUSH_EVERY_EPS=20` episodes, so if\n",
|
| 632 |
+
"the Kaggle session expires mid-run you can resume in a fresh session via\n",
|
| 633 |
+
"`HF_RESUME_REPO` in cell 8.\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"> **Want even faster?** Drop `PHASE_EPISODE_BUDGETS` to `{1: 10, 2: 15, 3: 25}`\n",
|
| 636 |
+
"> in cell 9 (50 episodes total, ~2.0 β 4.0 h). The curve will be choppier but\n",
|
| 637 |
+
"> still shows phase transitions cleanly.\n",
|
| 638 |
+
">\n",
|
| 639 |
+
"> **Want adaptive (early-stop)?** Set `PHASE_EPISODE_BUDGETS = None` in cell 9\n",
|
| 640 |
+
"> and `EARLY_STOP_ENABLED = True`; each phase will end the moment its reward\n",
|
| 641 |
+
"> target is sustained for `CONVERGENCE_WINDOW=3` consecutive groups."
|
| 642 |
]
|
| 643 |
},
|
| 644 |
{
|
|
|
|
| 647 |
"metadata": {},
|
| 648 |
"outputs": [],
|
| 649 |
"source": [
|
| 650 |
+
"# === CELL 13 β REAL TRAINING (4-6 h cell, fixed-budget curriculum) ===\n",
|
| 651 |
"metrics = train(\n",
|
| 652 |
" num_episodes=NUM_EPISODES,\n",
|
| 653 |
" group_size=GROUP_SIZE,\n",
|
|
|
|
| 663 |
" phase_min_win_rate=PHASE_MIN_WIN_RATE,\n",
|
| 664 |
" convergence_window=CONVERGENCE_WINDOW,\n",
|
| 665 |
" early_stop=EARLY_STOP_ENABLED,\n",
|
| 666 |
+
" phase_episode_budgets=PHASE_EPISODE_BUDGETS, # None -> early-stop mode\n",
|
| 667 |
")\n",
|
| 668 |
"print(f\"\\nTraining returned {len(metrics)} metric records.\")"
|
| 669 |
]
|