nihalaninihal commited on
Commit
5e0f2b1
·
1 Parent(s): 173a3e9

Update metrics format with drift/oversight tracking, add colab training notebook

Browse files
sentinelops_arena/metrics.py CHANGED
@@ -626,6 +626,27 @@ def format_comparison_metrics_html(
626
  ),
627
  [f"SE attacks: {u['social_eng_total']} / {t['social_eng_total']}"],
628
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  ]
630
 
631
  return (
 
626
  ),
627
  [f"SE attacks: {u['social_eng_total']} / {t['social_eng_total']}"],
628
  ),
629
+ _comparison_card(
630
+ "Oversight Accuracy",
631
+ _pct(u.get("oversight_accuracy", 0.0)),
632
+ _pct(t.get("oversight_accuracy", 0.0)),
633
+ _color_good_high(u.get("oversight_accuracy", 0.0)),
634
+ _color_good_high(t.get("oversight_accuracy", 0.0)),
635
+ _diff_indicator(u.get("oversight_accuracy", 0.0), t.get("oversight_accuracy", 0.0), lower_is_better=False),
636
+ [
637
+ f"Decisions: {u.get('total_oversight', 0)} / {t.get('total_oversight', 0)}",
638
+ f"Avg Expl Qual: {u.get('avg_explanation_quality', 0.0):.2f} / {t.get('avg_explanation_quality', 0.0):.2f}"
639
+ ],
640
+ ),
641
+ _comparison_card(
642
+ "Drift Adaptation",
643
+ _pct(u.get("drift_adaptation_rate", 0.0)),
644
+ _pct(t.get("drift_adaptation_rate", 0.0)),
645
+ _color_good_high(u.get("drift_adaptation_rate", 0.0)),
646
+ _color_good_high(t.get("drift_adaptation_rate", 0.0)),
647
+ _diff_indicator(u.get("drift_adaptation_rate", 0.0), t.get("drift_adaptation_rate", 0.0), lower_is_better=False),
648
+ [f"Detected: {u.get('drifts_detected', 0)} / {t.get('drifts_detected', 0)} of {u.get('drift_events', 0)}"],
649
+ ),
650
  ]
651
 
652
  return (
tasks/master_plan.md ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SentinelOps Arena -- Master Improvement Plan
2
+
3
+ **Generated:** Sunday March 8, 2026
4
+ **Deadline:** Sunday March 8, 2026 1:00 PM
5
+ **Synthesized from:** Researcher findings, code reviewer findings, sponsor track analysis, devil's advocate critique, gap analysis
6
+
7
+ ---
8
+
9
+ ## CONTEXT: Current State
10
+
11
+ The core environment is solid: 3 agents, 3 enterprise systems, 4 attack types, reward functions, randomized attacker, security metrics engine, and a polished Gradio UI with 4 tabs and a cybersecurity theme. The codebase compiles and the trained vs untrained worker comparison shows meaningful score differences.
12
+
13
+ **Three REQUIRED submission deliverables are NOT done:**
14
+ 1. HuggingFace Spaces deployment
15
+ 2. Google Colab training notebook
16
+ 3. Demo video on YouTube
17
+
18
+ **Partner tracks targeted:** Fleet AI ($10K, Scalable Oversight) and Patronus AI ($10K, Schema Drift)
19
+
20
+ ---
21
+
22
+ ## 1. CRITICAL FIXES (Must Do -- Submission Fails Without These)
23
+
24
+ ### C1. Deploy to HuggingFace Spaces
25
+ - **What:** Create HF Space, push code, verify it builds and runs
26
+ - **Files:** `requirements.txt`, `README.md` (frontmatter), `app.py`
27
+ - **Effort:** 30 min
28
+ - **Impact:** BLOCKER -- no live URL = no submission
29
+ - **Details:**
30
+ - Add `pandas>=2.0` to `requirements.txt` (missing, app.py imports it)
31
+ - Verify `gradio>=6.0.0` in requirements.txt matches README frontmatter `sdk_version: 6.9.0`
32
+ - Create Space at `huggingface.co/new-space`, SDK: Gradio, Hardware: CPU Basic
33
+ - Push with `git push hf main` or use `huggingface_hub.upload_folder()`
34
+ - Test all 4 tabs work on the live URL
35
+
36
+ ### C2. Create Colab Training Notebook
37
+ - **What:** Create `training/colab_training.ipynb` with working GRPO pipeline
38
+ - **Files:** New file: `training/colab_training.ipynb`
39
+ - **Effort:** 60-90 min
40
+ - **Impact:** BLOCKER -- submission requires "Minimal Training Script"
41
+ - **Details:**
42
+ - Reuse logic from `train.py` (it has everything needed)
43
+ - Use `Qwen/Qwen2.5-0.5B-Instruct` (fits free Colab T4)
44
+ - Use Unsloth for model loading, vanilla TRL GRPOTrainer for training
45
+ - Must show: env verification, data collection, model loading, GRPO config, at least a few training steps
46
+ - If openenv-core fails on Colab Python version, bundle standalone env code
47
+ - Add markdown cells explaining each step, mention partner tracks
48
+
49
+ ### C3. Record Demo Video
50
+ - **What:** 1-3 minute screen recording of Gradio app + voice/text narration
51
+ - **Files:** N/A (external -- YouTube upload)
52
+ - **Effort:** 30 min
53
+ - **Impact:** BLOCKER -- submission requires YouTube demo video
54
+ - **Details:**
55
+ - Show: episode replay (attack/adapt/flag cycle), untrained vs trained comparison, environment inspector
56
+ - Mention: 3-agent self-play, Fleet AI oversight, Patronus AI schema drift
57
+ - Keep simple -- QuickTime screen record, no fancy editing
58
+
59
+ ### C4. Verify Gradio App Launches Locally
60
+ - **What:** Run `python app.py` and test all 4 tabs
61
+ - **Files:** `app.py`, all imported modules
62
+ - **Effort:** 15 min
63
+ - **Impact:** HIGH -- if app crashes, HF Spaces will fail too
64
+ - **Note:** `tasks/todo.md` shows this is UNCHECKED
65
+
66
+ ---
67
+
68
+ ## 2. HIGH-IMPACT IMPROVEMENTS (Should Do -- Directly Impress Judges)
69
+
70
+ ### H1. Improve Oversight Explanation Quality Scoring (Fleet AI Track)
71
+ - **What:** Replace character-count explanation quality with structured quality scoring
72
+ - **Files:** `sentinelops_arena/environment.py:441`, `sentinelops_arena/demo.py:302-327`
73
+ - **Effort:** 20 min
74
+ - **Impact:** HIGH for Fleet AI ($10K) -- current scoring is `min(len(explanation) / 100.0, 1.0)` which is embarrassingly simplistic. Fleet AI judge Nicolai Ouporov will notice.
75
+ - **Details:**
76
+ - In `environment.py:441`, replace character-length heuristic with keyword-based quality scoring:
77
+ - +0.25 if explanation mentions the violation type (e.g., "policy violation", "social engineering")
78
+ - +0.25 if explanation references specific data (e.g., amount, field name, policy rule)
79
+ - +0.25 if explanation states the rule being violated (e.g., "max refund is $2000")
80
+ - +0.25 if explanation recommends corrective action
81
+ - In `demo.py` HeuristicOversight, improve the canned explanation strings to include specific data from the observation (e.g., "Worker issued refund exceeding policy max of $X. Current policy requires approval for amounts over $Y.")
82
+
83
+ ### H2. Add SLA Policy Drift to Ticketing (Patronus AI Track)
84
+ - **What:** Allow the attacker to change SLA deadlines, not just refund policies
85
+ - **Files:** `sentinelops_arena/systems/ticketing.py`, `sentinelops_arena/attacks.py`, `sentinelops_arena/demo.py`
86
+ - **Effort:** 20 min
87
+ - **Impact:** HIGH for Patronus AI ($10K) -- doubles the policy drift surface. Currently only billing has policy drift.
88
+ - **Details:**
89
+ - Add `TicketingSystem.apply_policy_drift(changes)` in `ticketing.py` that modifies `self.sla_rules`
90
+ - In `attacks.py:_execute_policy_drift()`, route to ticketing system when target is TICKETING
91
+ - In `demo.py` RandomizedAttacker, add SLA policy drift options to `POLICY_DRIFT_CHANGES`
92
+ - Worker should call `get_current_policy("sla")` to discover changed SLA rules
93
+
94
+ ### H3. Add Oversight Metrics to Dashboard
95
+ - **What:** Add oversight-specific metrics (explanation quality, detection accuracy) to the metrics engine and Gradio UI
96
+ - **Files:** `sentinelops_arena/metrics.py`, `app.py`
97
+ - **Effort:** 25 min
98
+ - **Impact:** HIGH for Fleet AI ($10K) -- currently NO oversight-specific metrics exist in the dashboard
99
+ - **Details:**
100
+ - In `metrics.py`, add to `compute_episode_metrics()`:
101
+ - `oversight_accuracy`: correct flags + correct approvals / total oversight decisions
102
+ - `avg_explanation_quality`: average explanation quality score across all oversight decisions
103
+ - Add a new metric card for oversight accuracy in `format_metrics_html()`
104
+ - This makes the Fleet AI story visible in the demo
105
+
106
+ ### H4. Add Drift-Specific Metrics
107
+ - **What:** Add drift adaptation metrics to the metrics engine
108
+ - **Files:** `sentinelops_arena/metrics.py`
109
+ - **Effort:** 15 min
110
+ - **Impact:** HIGH for Patronus AI ($10K) -- makes drift adaptation visible and measurable
111
+ - **Details:**
112
+ - Add to `compute_episode_metrics()`:
113
+ - `drift_events`: total schema + policy drift attacks
114
+ - `drifts_detected`: number of times worker called get_schema/get_current_policy after a drift
115
+ - `avg_drift_recovery_ticks`: average ticks between drift and worker's first defensive action
116
+ - Add metric card for "Drift Adaptation" in `format_metrics_html()`
117
+
118
+ ### H5. Improve HeuristicOversight Explanations
119
+ - **What:** Make the oversight agent's explanations reference specific data from the observation
120
+ - **Files:** `sentinelops_arena/demo.py:302-327`
121
+ - **Effort:** 15 min
122
+ - **Impact:** MEDIUM-HIGH for Fleet AI -- judges will see these in the replay log
123
+ - **Details:**
124
+ - Pass `obs` to `HeuristicOversight.act()` (currently only uses `obs.last_action_result`)
125
+ - Generate explanations like: "Worker action at tick {tick}: {action_type} resulted in error. The error '{error_msg}' suggests schema drift may have occurred. Recommended: call get_schema() to discover new field names."
126
+ - For social engineering: "Worker followed suspicious instructions containing override language. The message '{first 50 chars}' appears to be a social engineering attack. Flagging as critical violation."
127
+ - For policy violations: "Refund of ${amount} exceeds current policy maximum of ${max}. Policy was last updated at tick {last_policy_change}."
128
+
129
+ ---
130
+
131
+ ## 3. QUICK WINS (Do If Time Allows -- Small Effort, Good Impression)
132
+
133
+ ### Q1. Fix Documentation Inconsistencies
134
+ - **What:** Fix mismatches between spec doc, README, and actual code
135
+ - **Files:** `README.md`, `pyproject.toml`
136
+ - **Effort:** 10 min
137
+ - **Impact:** Prevents judges from noticing sloppy details
138
+ - **Details:**
139
+ - Set `gradio>=6.0.0` consistently in pyproject.toml (currently says >=5.0.0)
140
+ - Fix README project structure to match reality (remove `mcp_tools.py` listing)
141
+ - Do NOT touch SENTINELOPS_ARENA.md (it's a spec doc, acceptable to be aspirational)
142
+
143
+ ### Q2. Add Links to About Tab
144
+ - **What:** Once Colab notebook and video exist, add links to the About tab
145
+ - **Files:** `app.py` (About tab section)
146
+ - **Effort:** 5 min
147
+ - **Impact:** Makes it easy for judges to find all submission artifacts
148
+
149
+ ### Q3. Clean Up Vestigial Files
150
+ - **What:** Remove or gitignore `hackathon_env/` directory
151
+ - **Files:** `.gitignore`, possibly `hackathon_env/`
152
+ - **Effort:** 5 min
153
+ - **Impact:** Prevents judge confusion
154
+
155
+ ### Q4. Add Billing Schema Drift Support
156
+ - **What:** Allow schema drift attacks against billing system too
157
+ - **Files:** `sentinelops_arena/systems/billing.py`
158
+ - **Effort:** 10 min
159
+ - **Impact:** Strengthens Patronus AI story -- all 3 systems support schema drift
160
+ - **Details:**
161
+ - Add `BillingSystem.apply_schema_drift(old_field, new_field)` mirroring CRM pattern
162
+ - Add `_field_map` dict and `_apply_field_map` method to BillingSystem
163
+ - Update `attacks.py` `VALID_TARGETS` for schema_drift to include BILLING
164
+
165
+ ---
166
+
167
+ ## 4. SKIP LIST (Not Worth the Time)
168
+
169
+ | Item | Reason |
170
+ |------|--------|
171
+ | Compound attacks (2-3 simultaneous) | 2+ hours, marginal judge impact |
172
+ | Compliance drift (new required fields) | 1+ hours, nice but not critical |
173
+ | A2A protocol | Already marked "Cut" in spec, not in submission requirements |
174
+ | Docker support | HF Spaces uses Gradio SDK directly |
175
+ | MCP-X gateway demo | MCP tools in environment.py are sufficient |
176
+ | Full GRPO convergence | Pipeline working is enough -- convergence not required |
177
+ | Real datetime-based SLA | Tick-based is fine for demo |
178
+ | Multi-GPU training | Overkill for hackathon |
179
+ | Refactoring codebase | No judge impact, waste of time |
180
+
181
+ ---
182
+
183
+ ## EXECUTION ORDER (Recommended)
184
+
185
+ **Phase 1 (0:00 - 0:15): Verify and fix basics**
186
+ 1. C4: Verify Gradio app launches locally
187
+ 2. Q1: Fix requirements.txt (add pandas) and pyproject.toml consistency
188
+
189
+ **Phase 2 (0:15 - 1:00): High-impact code improvements**
190
+ 3. H1: Improve oversight explanation quality scoring (20 min)
191
+ 4. H2: Add SLA policy drift to ticketing (20 min)
192
+ 5. H5: Improve HeuristicOversight explanations (15 min)
193
+
194
+ **Phase 3 (1:00 - 1:30): Metrics improvements**
195
+ 6. H3: Add oversight metrics to dashboard (25 min)
196
+ 7. H4: Add drift-specific metrics (15 min)
197
+
198
+ **Phase 4 (1:30 - 2:00): Deployment**
199
+ 8. C1: Deploy to HuggingFace Spaces (30 min)
200
+
201
+ **Phase 5 (2:00 - 3:15): Required deliverables**
202
+ 9. C2: Create Colab training notebook (75 min)
203
+
204
+ **Phase 6 (3:15 - 3:45): Video and submission**
205
+ 10. C3: Record demo video (30 min)
206
+
207
+ **Phase 7 (3:45 - 4:00): Final polish**
208
+ 11. Q2: Add links to About tab (5 min)
209
+ 12. Q3: Clean up vestigial files (5 min)
210
+ 13. Final push and submit (5 min)
211
+
212
+ ---
213
+
214
+ ## KEY JUDGE CONSIDERATIONS
215
+
216
+ - **Nicolai Ouporov (Fleet AI):** Cares about scalable oversight. Will check: Does the oversight agent actually explain violations well? Is explanation quality tracked? Does training improve oversight?
217
+ - **Darshan Deshpande (Patronus AI):** Cares about schema drift. Will check: How many drift types? Does the worker adapt? Is drift visible in the UI?
218
+ - **Daniel Han (Unsloth):** Cares about Unsloth/TRL integration. Will check: Does the Colab notebook use Unsloth correctly? Does training actually work?
219
+ - **Sanyam Bhutani (Meta):** Cares about OpenEnv quality. Will check: Is the environment well-structured? Does step/reset/state work properly?
220
+ - **Benjamin Burtenshaw (HuggingFace):** Cares about Hub deployment. Will check: Is the HF Space functional and polished?
training/colab_training.ipynb ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ }
16
+ },
17
+ "cells": [
18
+ {
19
+ "cell_type": "markdown",
20
+ "source": [
21
+ "# SentinelOps Arena \u2014 GRPO Training with Unsloth\n",
22
+ "\n",
23
+ "This notebook demonstrates how to train the **Worker Agent** using GRPO (Group Relative Policy Optimization) on the SentinelOps Arena environment.\n",
24
+ "\n",
25
+ "SentinelOps Arena is a multi-agent self-play RL environment for enterprise security training built on OpenEnv. We are targeting the **Fleet AI (Scalable Oversight)** and **Patronus AI (Schema Drift)** tracks."
26
+ ],
27
+ "metadata": {
28
+ "id": "intro"
29
+ }
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "source": [
34
+ "## 1. Setup Environment"
35
+ ],
36
+ "metadata": {
37
+ "id": "setup-header"
38
+ }
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {
44
+ "id": "install-deps"
45
+ },
46
+ "outputs": [],
47
+ "source": [
48
+ "!pip install \"openenv-core[core]>=0.2.0\" mcp fastmcp pydantic pandas\n",
49
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
50
+ "!pip install --no-deps \"trl<0.9.0\" peft accelerate bitsandbytes"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {
57
+ "id": "clone-repo"
58
+ },
59
+ "outputs": [],
60
+ "source": [
61
+ "import os\n",
62
+ "if not os.path.exists(\"NexusEnv\"):\n",
63
+ " !git clone https://github.com/nihalnihalani/NexusEnv.git\n",
64
+ "import sys\n",
65
+ "sys.path.append(\"/content/NexusEnv\")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "source": [
71
+ "## 2. Collect Training Data via Self-Play\n",
72
+ "\n",
73
+ "We run the environment using our heuristic agents to generate the initial \"prompts\" that the Worker agent will face during training."
74
+ ],
75
+ "metadata": {
76
+ "id": "collect-header"
77
+ }
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
+ "id": "collect-data"
84
+ },
85
+ "outputs": [],
86
+ "source": [
87
+ "import json\n",
88
+ "from datasets import Dataset\n",
89
+ "from NexusEnv.train import build_training_dataset, WORKER_SYSTEM_PROMPT\n",
90
+ "\n",
91
+ "NUM_EPISODES = 5\n",
92
+ "print(f\"Collecting training data from {NUM_EPISODES} episodes...\")\n",
93
+ "dataset_raw = build_training_dataset(num_episodes=NUM_EPISODES, target_agent=\"worker\")\n",
94
+ "\n",
95
+ "prompts = []\n",
96
+ "for d in dataset_raw:\n",
97
+ " messages = [\n",
98
+ " {\"role\": \"system\", \"content\": WORKER_SYSTEM_PROMPT},\n",
99
+ " {\"role\": \"user\", \"content\": d[\"prompt\"]},\n",
100
+ " ]\n",
101
+ " prompts.append(messages)\n",
102
+ "\n",
103
+ "train_dataset = Dataset.from_dict({\"prompt\": prompts})\n",
104
+ "print(f\"Dataset generated with {len(train_dataset)} examples.\")"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "markdown",
109
+ "source": [
110
+ "## 3. Load Model with Unsloth\n",
111
+ "\n",
112
+ "We use `Qwen/Qwen2.5-0.5B-Instruct` as it fits comfortably in a free Colab T4 GPU."
113
+ ],
114
+ "metadata": {
115
+ "id": "load-header"
116
+ }
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {
122
+ "id": "load-model"
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "from unsloth import FastLanguageModel\n",
127
+ "\n",
128
+ "model_name = \"unsloth/Qwen2.5-0.5B-Instruct\"\n",
129
+ "\n",
130
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
131
+ " model_name=model_name,\n",
132
+ " max_seq_length=2048,\n",
133
+ " load_in_4bit=True,\n",
134
+ " fast_inference=True, # Enable vLLM fast inference\n",
135
+ ")\n",
136
+ "\n",
137
+ "model = FastLanguageModel.get_peft_model(\n",
138
+ " model,\n",
139
+ " r=16,\n",
140
+ " target_modules=[\n",
141
+ " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
142
+ " \"gate_proj\", \"up_proj\", \"down_proj\",\n",
143
+ " ],\n",
144
+ " lora_alpha=16,\n",
145
+ " lora_dropout=0,\n",
146
+ " bias=\"none\",\n",
147
+ " use_gradient_checkpointing=\"unsloth\",\n",
148
+ ")"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "source": [
154
+ "## 4. GRPO Training\n",
155
+ "\n",
156
+ "We set up the GRPO configuration and launch the training process."
157
+ ],
158
+ "metadata": {
159
+ "id": "train-header"
160
+ }
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "id": "train"
167
+ },
168
+ "outputs": [],
169
+ "source": [
170
+ "from trl import GRPOConfig, GRPOTrainer\n",
171
+ "from NexusEnv.train import make_reward_function\n",
172
+ "\n",
173
+ "reward_fn = make_reward_function(\"worker\")\n",
174
+ "\n",
175
+ "grpo_config = GRPOConfig(\n",
176
+ " output_dir=\"./sentinelops-grpo-worker\",\n",
177
+ " num_train_epochs=1,\n",
178
+ " per_device_train_batch_size=2,\n",
179
+ " gradient_accumulation_steps=4,\n",
180
+ " num_generations=4,\n",
181
+ " max_completion_length=256,\n",
182
+ " max_prompt_length=512,\n",
183
+ " learning_rate=5e-6,\n",
184
+ " logging_steps=1,\n",
185
+ " report_to=\"none\",\n",
186
+ ")\n",
187
+ "\n",
188
+ "trainer = GRPOTrainer(\n",
189
+ " model=model,\n",
190
+ " processing_class=tokenizer,\n",
191
+ " reward_funcs=[reward_fn],\n",
192
+ " args=grpo_config,\n",
193
+ " train_dataset=train_dataset,\n",
194
+ ")\n",
195
+ "\n",
196
+ "trainer.train()"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "source": [
202
+ "## 5. Save the Trained Model\n",
203
+ "\n",
204
+ "Finally, we save our GRPO-trained LoRA weights."
205
+ ],
206
+ "metadata": {
207
+ "id": "save-header"
208
+ }
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {
214
+ "id": "save"
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "output_dir = \"./sentinelops-grpo-worker\"\n",
219
+ "trainer.save_model(output_dir)\n",
220
+ "tokenizer.save_pretrained(output_dir)\n",
221
+ "print(\"Model saved successfully!\")"
222
+ ]
223
+ }
224
+ ]
225
+ }