Spaces:
Running
Running
Commit ·
d52b449
1
Parent(s): fcf34b9
Fix format_comparison_metrics_html to accept run_comparison() dict directly
- metrics.py: format_comparison_metrics_html() now accepts either two
pre-computed metrics dicts (original API) or the single combined dict
returned by run_comparison(). When called with one arg, it auto-computes
metrics from each episode's log via compute_episode_metrics().
- colab_training.ipynb: Fix broken import paths in cells 5 and 11.
NexusEnv.train does not exist as a package — train.py lives at the repo
root, which is on sys.path after the git clone + sys.path.insert(). Both
cells now use `from train import ...` so they resolve correctly in Colab.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
sentinelops_arena/metrics.py
CHANGED
|
@@ -488,16 +488,33 @@ def format_metrics_html(metrics: dict[str, Any]) -> str:
|
|
| 488 |
|
| 489 |
def format_comparison_metrics_html(
|
| 490 |
untrained_metrics: dict[str, Any],
|
| 491 |
-
trained_metrics: dict[str, Any],
|
| 492 |
) -> str:
|
| 493 |
"""Render untrained vs. trained metrics side-by-side with diff indicators.
|
| 494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
Parameters
|
| 496 |
----------
|
| 497 |
untrained_metrics : dict
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
Returns
|
| 503 |
-------
|
|
@@ -505,6 +522,16 @@ def format_comparison_metrics_html(
|
|
| 505 |
Self-contained HTML snippet showing both metric sets with arrows
|
| 506 |
indicating improvement (green) or regression (red).
|
| 507 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
def _diff_indicator(
|
| 510 |
before: float,
|
|
|
|
| 488 |
|
| 489 |
def format_comparison_metrics_html(
|
| 490 |
untrained_metrics: dict[str, Any],
|
| 491 |
+
trained_metrics: dict[str, Any] | None = None,
|
| 492 |
) -> str:
|
| 493 |
"""Render untrained vs. trained metrics side-by-side with diff indicators.
|
| 494 |
|
| 495 |
+
Accepts two calling conventions:
|
| 496 |
+
|
| 497 |
+
1. **Two metrics dicts** (original API)::
|
| 498 |
+
|
| 499 |
+
format_comparison_metrics_html(untrained_metrics, trained_metrics)
|
| 500 |
+
|
| 501 |
+
2. **Combined results dict** (output of ``run_comparison()``)::
|
| 502 |
+
|
| 503 |
+
format_comparison_metrics_html(run_comparison(seed=42))
|
| 504 |
+
|
| 505 |
+
The combined dict must have the shape
|
| 506 |
+
``{"untrained": {"log": [...], ...}, "trained": {"log": [...], ...}}``.
|
| 507 |
+
Episode metrics are computed automatically from each log.
|
| 508 |
+
|
| 509 |
Parameters
|
| 510 |
----------
|
| 511 |
untrained_metrics : dict
|
| 512 |
+
Either a pre-computed metrics dict (original API) **or** the combined
|
| 513 |
+
comparison results dict returned by ``run_comparison()``.
|
| 514 |
+
trained_metrics : dict, optional
|
| 515 |
+
Pre-computed metrics for the trained episode. Required when
|
| 516 |
+
``untrained_metrics`` is a plain metrics dict. Omit when passing the
|
| 517 |
+
combined results dict.
|
| 518 |
|
| 519 |
Returns
|
| 520 |
-------
|
|
|
|
| 522 |
Self-contained HTML snippet showing both metric sets with arrows
|
| 523 |
indicating improvement (green) or regression (red).
|
| 524 |
"""
|
| 525 |
+
# --- Handle combined results dict from run_comparison() ---
|
| 526 |
+
if trained_metrics is None:
|
| 527 |
+
combined = untrained_metrics
|
| 528 |
+
if "untrained" not in combined or "trained" not in combined:
|
| 529 |
+
raise ValueError(
|
| 530 |
+
"When called with a single argument, the dict must contain "
|
| 531 |
+
"'untrained' and 'trained' keys (output of run_comparison())."
|
| 532 |
+
)
|
| 533 |
+
untrained_metrics = compute_episode_metrics(combined["untrained"]["log"])
|
| 534 |
+
trained_metrics = compute_episode_metrics(combined["trained"]["log"])
|
| 535 |
|
| 536 |
def _diff_indicator(
|
| 537 |
before: float,
|
training/colab_training.ipynb
CHANGED
|
@@ -61,7 +61,7 @@
|
|
| 61 |
"id": "collect-data"
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
-
"source": "from
|
| 65 |
},
|
| 66 |
{
|
| 67 |
"cell_type": "markdown",
|
|
@@ -109,7 +109,7 @@
|
|
| 109 |
"id": "save"
|
| 110 |
},
|
| 111 |
"outputs": [],
|
| 112 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom
|
| 113 |
},
|
| 114 |
{
|
| 115 |
"cell_type": "markdown",
|
|
|
|
| 61 |
"id": "collect-data"
|
| 62 |
},
|
| 63 |
"outputs": [],
|
| 64 |
+
"source": "from train import collect_multi_agent_data, build_training_dataset\nfrom train import WORKER_SYSTEM_PROMPT, ATTACKER_SYSTEM_PROMPT, OVERSIGHT_SYSTEM_PROMPT\nfrom train import AGENT_CONFIGS\n\n# Run a single episode and show stats for each agent\nfor role in [\"worker\", \"attacker\", \"oversight\"]:\n data = collect_multi_agent_data(seed=42, target_agent=role)\n avg_r = sum(d[\"reward\"] for d in data) / max(len(data), 1)\n print(f\"{role:>10}: {len(data)} turns, avg_reward={avg_r:.3f}\")"
|
| 65 |
},
|
| 66 |
{
|
| 67 |
"cell_type": "markdown",
|
|
|
|
| 109 |
"id": "save"
|
| 110 |
},
|
| 111 |
"outputs": [],
|
| 112 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom train import make_reward_function\n\n# Environment-executing reward function\nreward_fn = make_reward_function(TARGET_AGENT)\n\ngrpo_config = GRPOConfig(\n output_dir=f\"./sentinelops-grpo-{TARGET_AGENT}\",\n max_steps=300, # Official recommendation\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=2, # GRPO group size\n max_completion_length=256,\n max_prompt_length=512,\n learning_rate=5e-5, # Official reference: 5e-5\n temperature=1.0, # Official reference: 1.0\n logging_steps=1,\n save_steps=50,\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[reward_fn],\n args=grpo_config,\n train_dataset=train_dataset,\n)\n\nprint(f\"Starting GRPO training for {TARGET_AGENT}...\")\nprint(f\" max_steps={grpo_config.max_steps}, lr={grpo_config.learning_rate}\")\nprint(f\" num_generations={grpo_config.num_generations}, temp={grpo_config.temperature}\")\ntrainer.train()"
|
| 113 |
},
|
| 114 |
{
|
| 115 |
"cell_type": "markdown",
|