Kartik Goyal committed on
Commit
88c89bd
·
1 Parent(s): 39124e2

updated grpo logic

Browse files
Files changed (3) hide show
  1. README.md +0 -11
  2. grpo_train.ipynb +651 -538
  3. grpo_train.py +87 -23
README.md CHANGED
@@ -1,14 +1,3 @@
1
- ---
2
- title: MetaGuard Ad Policy Sandbox
3
- emoji: 🛡
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: docker
7
- app_port: 8000
8
- pinned: false
9
- license: mit
10
- ---
11
-
12
  # MetaGuard: A Multi-App RL Environment for Enterprise Ad Policy Compliance
13
 
14
  > An OpenEnv-compatible reinforcement learning environment that forces an LLM agent
 
 
 
 
 
 
 
 
 
 
 
 
1
  # MetaGuard: A Multi-App RL Environment for Enterprise Ad Policy Compliance
2
 
3
  > An OpenEnv-compatible reinforcement learning environment that forces an LLM agent
grpo_train.ipynb CHANGED
@@ -1,540 +1,653 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "A100"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  },
9
- "kernelspec": {
10
- "display_name": "Python 3",
11
- "name": "python3"
12
- },
13
- "language_info": {
14
- "name": "python"
15
- },
16
- "accelerator": "GPU"
17
- },
18
- "cells": [
19
- {
20
- "cell_type": "markdown",
21
- "metadata": {},
22
- "source": [
23
- "# 🛡️ MetaGuard — GRPO Training Notebook\n",
24
- "\n",
25
- "**Team:** Parth Singhal, Mehakveer Kaur, Kartik Goyal \n",
26
- "**HF Space:** https://huggingface.co/spaces/parth-1/MetaGuard \n",
27
- "**Hackathon:** OpenEnv — Meta × Scaler \n",
28
- "\n",
29
- "This notebook trains **Llama 3.1 8B** using GRPO on the MetaGuard Ad Policy Compliance environment.\n",
30
- "\n",
31
- "### What this trains:\n",
32
- "- Agent learns to follow structured SOP: `query_regulations → gather signals → submit_audit → decide`\n",
33
- "- Reward shaped by correctness, sequence compliance, API failure recovery\n",
34
- "- Environment hosted on HF Space (A100 runs both training + env)"
35
- ]
36
- },
37
- {
38
- "cell_type": "markdown",
39
- "metadata": {},
40
- "source": [
41
- "## Cell 1 — Install Dependencies"
42
- ]
43
- },
44
- {
45
- "cell_type": "code",
46
- "execution_count": null,
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "!pip install unsloth trl transformers datasets accelerate peft -q\n",
51
- "!pip install openenv-core==0.2.1 --no-deps -q\n",
52
- "!pip install fastapi uvicorn pydantic requests openai matplotlib -q\n",
53
- "print('✅ Dependencies installed')"
54
- ]
55
- },
56
- {
57
- "cell_type": "markdown",
58
- "metadata": {},
59
- "source": [
60
- "## Cell 2 — Clone Repo"
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": null,
66
- "metadata": {},
67
- "outputs": [],
68
- "source": [
69
- "import os\n",
70
- "\n",
71
- "if not os.path.exists('meta-ad-policy-sandbox'):\n",
72
- " !git clone https://github.com/Parth380/meta-ad-policy-sandbox.git\n",
73
- "\n",
74
- "%cd meta-ad-policy-sandbox\n",
75
- "print('✅ Repo ready')"
76
- ]
77
- },
78
- {
79
- "cell_type": "markdown",
80
- "metadata": {},
81
- "source": [
82
- "## Cell 3 — Config (SET THESE)"
83
- ]
84
- },
85
- {
86
- "cell_type": "code",
87
- "execution_count": null,
88
- "metadata": {},
89
- "outputs": [],
90
- "source": [
91
- "import os\n",
92
- "\n",
93
- "os.environ['ENV_URL'] = 'https://parth-1-metaguard.hf.space' # your HF Space URL\n",
94
- "os.environ['HF_REPO'] = 'parth-1/metaguard-llama3.1-8b-grpo' # model push destination\n",
95
- "os.environ['HF_TOKEN'] = '' # ← paste your HF write token here\n",
96
- "\n",
97
- "ENV_URL = os.environ['ENV_URL']\n",
98
- "HF_TOKEN = os.environ['HF_TOKEN']\n",
99
- "HF_REPO = os.environ['HF_REPO']\n",
100
- "\n",
101
- "print(f'ENV_URL : {ENV_URL}')\n",
102
- "print(f'HF_REPO : {HF_REPO}')\n",
103
- "print(f'HF_TOKEN : {\"set ✅\" if HF_TOKEN else \"MISSING ❌\"}')"
104
- ]
105
- },
106
- {
107
- "cell_type": "markdown",
108
- "metadata": {},
109
- "source": [
110
- "## Cell 4 — Wake Up HF Space"
111
- ]
112
- },
113
- {
114
- "cell_type": "code",
115
- "execution_count": null,
116
- "metadata": {},
117
- "outputs": [],
118
- "source": [
119
- "import requests, time\n",
120
- "\n",
121
- "print('Waking up HF Space...')\n",
122
- "for i in range(20):\n",
123
- " try:\n",
124
- " r = requests.post(\n",
125
- " f\"{ENV_URL}/reset\",\n",
126
- " json={'task_id': 'task_1_healthcare'},\n",
127
- " timeout=10\n",
128
- " )\n",
129
- " if r.status_code == 200:\n",
130
- " print(f'✅ Environment ready (attempt {i+1})')\n",
131
- " break\n",
132
- " except Exception as e:\n",
133
- " print(f' attempt {i+1}: waiting... ({e})')\n",
134
- " time.sleep(3)\n",
135
- "else:\n",
136
- " raise RuntimeError('❌ ENV not reachable after 20 attempts')"
137
- ]
138
- },
139
- {
140
- "cell_type": "markdown",
141
- "metadata": {},
142
- "source": [
143
- "## Cell 5 — Imports + Helpers"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": null,
149
- "metadata": {},
150
- "outputs": [],
151
- "source": [
152
- "import json\n",
153
- "import random\n",
154
- "import torch\n",
155
- "import matplotlib.pyplot as plt\n",
156
- "from collections import defaultdict\n",
157
- "\n",
158
- "from datasets import Dataset\n",
159
- "from unsloth import FastLanguageModel, PatchFastRL\n",
160
- "from trl import GRPOTrainer, GRPOConfig\n",
161
- "\n",
162
- "PatchFastRL('GRPO', FastLanguageModel)\n",
163
- "\n",
164
- "ALLOWED_ACTIONS = [\n",
165
- " 'query_regulations', 'analyze_image', 'check_advertiser_history',\n",
166
- " 'request_landing_page', 'request_id_verification',\n",
167
- " 'submit_audit', 'approve', 'reject',\n",
168
- "]\n",
169
- "\n",
170
- "class EnvClient:\n",
171
- " def __init__(self, url):\n",
172
- " self.url = url\n",
173
- " def reset(self, task_id):\n",
174
- " return requests.post(f'{self.url}/reset', json={'task_id': task_id}, timeout=8).json()\n",
175
- " def step(self, action):\n",
176
- " return requests.post(f'{self.url}/step', json={'action': action}, timeout=8).json()\n",
177
- "\n",
178
- "def safe_step(client, action):\n",
179
- " for _ in range(3):\n",
180
- " try:\n",
181
- " return client.step(action)\n",
182
- " except:\n",
183
- " time.sleep(0.5)\n",
184
- " return {'reward': -0.3}\n",
185
- "\n",
186
- "def extract_json(text):\n",
187
- " try:\n",
188
- " if '```' in text:\n",
189
- " text = text.split('```')[1]\n",
190
- " if text.startswith('json'):\n",
191
- " text = text[4:]\n",
192
- " return json.loads(text.strip())\n",
193
- " except:\n",
194
- " return None\n",
195
- "\n",
196
- "print('✅ Helpers loaded')"
197
- ]
198
- },
199
- {
200
- "cell_type": "markdown",
201
- "metadata": {},
202
- "source": [
203
- "## Cell 6 — Dataset"
204
- ]
205
- },
206
- {
207
- "cell_type": "code",
208
- "execution_count": null,
209
- "metadata": {},
210
- "outputs": [],
211
- "source": [
212
- "BASE_SCENARIOS = [\n",
213
- " {'task_id': 'task_1_healthcare',\n",
214
- " 'text': \"Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.\",\n",
215
- " 'actions_already_taken': [], 'setup_actions': []},\n",
216
- " {'task_id': 'task_2_financial',\n",
217
- " 'text': \"Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.\",\n",
218
- " 'actions_already_taken': [], 'setup_actions': []},\n",
219
- " {'task_id': 'task_3_multimodal',\n",
220
- " 'text': 'Multimodal ad: image may contain hidden violation. No actions taken yet.',\n",
221
- " 'actions_already_taken': [], 'setup_actions': []},\n",
222
- " {'task_id': 'task_6_conflict',\n",
223
- " 'text': 'High-trust advertiser but policy borderline. No actions taken yet.',\n",
224
- " 'actions_already_taken': [], 'setup_actions': []},\n",
225
- " {'task_id': 'task_7_ambiguous',\n",
226
- " 'text': 'Ambiguous wellness ad. Policy confidence low. No actions taken yet.',\n",
227
- " 'actions_already_taken': [], 'setup_actions': []},\n",
228
- " {'task_id': 'task_8_adversarial',\n",
229
- " 'text': 'Natural supplement ad. Image may hide violation. No actions taken yet.',\n",
230
- " 'actions_already_taken': [], 'setup_actions': []},\n",
231
- " {'task_id': 'task_9_dependency_trap',\n",
232
- " 'text': 'Certified wellness product. Regulations and CRM look clean. No actions taken yet.',\n",
233
- " 'actions_already_taken': [], 'setup_actions': []},\n",
234
- " {'task_id': 'task_10_failure',\n",
235
- " 'text': 'Healthcare ad. First API call may fail. No actions taken yet.',\n",
236
- " 'actions_already_taken': [], 'setup_actions': []},\n",
237
- " {'task_id': 'task_1_healthcare',\n",
238
- " 'text': 'Healthcare ad. Policy already queried.',\n",
239
- " 'actions_already_taken': ['query_regulations'],\n",
240
- " 'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
241
- " {'task_id': 'task_2_financial',\n",
242
- " 'text': 'Financial ad. Policy and history checked. Submit audit next.',\n",
243
- " 'actions_already_taken': ['query_regulations', 'check_advertiser_history'],\n",
244
- " 'setup_actions': [\n",
245
- " {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
246
- " {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
247
- " ]},\n",
248
- " {'task_id': 'task_2_financial',\n",
249
- " 'text': 'Financial ad. Policy, history, audit all complete. Make final decision.',\n",
250
- " 'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'submit_audit'],\n",
251
- " 'setup_actions': [\n",
252
- " {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
253
- " {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
254
- " {'action_type': 'submit_audit', 'reasoning': 'audit log'},\n",
255
- " ]},\n",
256
- "]\n",
257
- "\n",
258
- "PROMPT_TEMPLATE = \"\"\"You are an enterprise Ad Policy Compliance Agent.\n",
259
- "\n",
260
- "You MUST choose exactly ONE action_type from this list (any other value is invalid):\n",
261
- "- query_regulations\n",
262
- "- analyze_image\n",
263
- "- check_advertiser_history\n",
264
- "- request_landing_page\n",
265
- "- request_id_verification\n",
266
- "- submit_audit\n",
267
- "- approve\n",
268
- "- reject\n",
269
- "\n",
270
- "REQUIRED PHASE ORDER:\n",
271
- "1. query_regulations -> always first\n",
272
- "2. analyze_image / check_advertiser_history -> gather signals\n",
273
- "3. submit_audit -> always before final decision\n",
274
- "4. approve OR reject -> only after audit\n",
275
- "\n",
276
- "HARD RULES:\n",
277
- "- NEVER repeat an action listed in `actions_already_taken`.\n",
278
- "- Respond with ONLY a valid JSON object. No markdown, no prose.\n",
279
- "\n",
280
- "Required format:\n",
281
- "{{\\\"action_type\\\": \\\"<one_of_the_actions_above>\\\", \\\"reasoning\\\": \\\"<short reason>\\\"}}\n",
282
- "\n",
283
- "Scenario: {text}\n",
284
- "actions_already_taken: {actions_already_taken}\n",
285
- "\n",
286
- "Your next action?\"\"\"\n",
287
- "\n",
288
- "def build_dataset():\n",
289
- " rows = []\n",
290
- " for s in BASE_SCENARIOS:\n",
291
- " prompt = PROMPT_TEMPLATE.format(\n",
292
- " text=s['text'],\n",
293
- " actions_already_taken=json.dumps(s['actions_already_taken']),\n",
294
- " )\n",
295
- " rows.append({'prompt': prompt, 'task_id': s['task_id'], 'setup_actions': s['setup_actions']})\n",
296
- " return Dataset.from_list(rows * 10)\n",
297
- "\n",
298
- "dataset = build_dataset()\n",
299
- "print(f'✅ Dataset: {len(dataset)} examples')"
300
- ]
301
- },
302
- {
303
- "cell_type": "markdown",
304
- "metadata": {},
305
- "source": [
306
- "## Cell 7 — Reward Function"
307
- ]
308
- },
309
- {
310
- "cell_type": "code",
311
- "execution_count": null,
312
- "metadata": {},
313
- "outputs": [],
314
- "source": [
315
- "# Track rewards for plotting\n",
316
- "reward_log = []\n",
317
- "step_log = []\n",
318
- "global_step_counter = [0]\n",
319
- "\n",
320
- "def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):\n",
321
- " client = EnvClient(ENV_URL)\n",
322
- " rewards = []\n",
323
- "\n",
324
- " for completion, t_id, setup in zip(completions, task_id, setup_actions):\n",
325
- " parsed = extract_json(completion)\n",
326
- " if not parsed:\n",
327
- " rewards.append(-1.0)\n",
328
- " continue\n",
329
- "\n",
330
- " action_type = parsed.get('action_type')\n",
331
- " if action_type not in ALLOWED_ACTIONS:\n",
332
- " rewards.append(-1.0)\n",
333
- " continue\n",
334
- "\n",
335
- " action = {\n",
336
- " 'action_type': action_type,\n",
337
- " 'reasoning': parsed.get('reasoning', 'format-compliant'),\n",
338
- " }\n",
339
- "\n",
340
- " try:\n",
341
- " client.reset(t_id)\n",
342
- " for s in setup:\n",
343
- " safe_step(client, s)\n",
344
- "\n",
345
- " result = safe_step(client, action)\n",
346
- " env_reward = float(result.get('reward', -0.2))\n",
347
- " status_msg = (result.get('status_message') or '').lower()\n",
348
- "\n",
349
- " rejected = (\n",
350
- " 'api failure' in status_msg\n",
351
- " or 'invalid action' in status_msg\n",
352
- " or 'must call' in status_msg\n",
353
- " )\n",
354
- " shaped = -0.5 if rejected else 0.5 + env_reward\n",
355
- " rewards.append(shaped)\n",
356
- "\n",
357
- " except Exception:\n",
358
- " rewards.append(-0.3)\n",
359
- "\n",
360
- " # Log for plot\n",
361
- " avg = sum(rewards) / len(rewards) if rewards else 0.0\n",
362
- " global_step_counter[0] += 1\n",
363
- " reward_log.append(avg)\n",
364
- " step_log.append(global_step_counter[0])\n",
365
- "\n",
366
- " return rewards\n",
367
- "\n",
368
- "print('✅ Reward function ready')"
369
- ]
370
- },
371
- {
372
- "cell_type": "markdown",
373
- "metadata": {},
374
- "source": [
375
- "## Cell 8 — Load Model"
376
- ]
377
- },
378
- {
379
- "cell_type": "code",
380
- "execution_count": null,
381
- "metadata": {},
382
- "outputs": [],
383
- "source": [
384
- "USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3\n",
385
- "print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"CPU\"}')\n",
386
- "print(f'4-bit quant: {USE_4BIT}')\n",
387
- "\n",
388
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
389
- " model_name='unsloth/Llama-3.1-8B-Instruct',\n",
390
- " load_in_4bit=USE_4BIT,\n",
391
- " max_seq_length=2048,\n",
392
- " dtype=None,\n",
393
- ")\n",
394
- "\n",
395
- "model = FastLanguageModel.get_peft_model(\n",
396
- " model,\n",
397
- " r=32,\n",
398
- " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
399
- " lora_alpha=64,\n",
400
- " lora_dropout=0,\n",
401
- " bias='none',\n",
402
- " use_gradient_checkpointing='unsloth',\n",
403
- " random_state=3407,\n",
404
- ")\n",
405
- "print('✅ Model loaded')"
406
- ]
407
- },
408
- {
409
- "cell_type": "markdown",
410
- "metadata": {},
411
- "source": [
412
- "## Cell 9 — Train"
413
- ]
414
- },
415
- {
416
- "cell_type": "code",
417
- "execution_count": null,
418
- "metadata": {},
419
- "outputs": [],
420
- "source": [
421
- "trainer = GRPOTrainer(\n",
422
- " model=model,\n",
423
- " reward_funcs=[reward_environment],\n",
424
- " args=GRPOConfig(\n",
425
- " output_dir='outputs',\n",
426
- " learning_rate=2e-5,\n",
427
- " num_train_epochs=3,\n",
428
- " per_device_train_batch_size=2,\n",
429
- " gradient_accumulation_steps=4,\n",
430
- " num_generations=4,\n",
431
- " max_prompt_length=768,\n",
432
- " max_completion_length=128,\n",
433
- " logging_steps=5,\n",
434
- " warmup_ratio=0.1,\n",
435
- " bf16=True,\n",
436
- " report_to='none',\n",
437
- " ),\n",
438
- " train_dataset=dataset,\n",
439
- " tokenizer=tokenizer,\n",
440
- ")\n",
441
- "\n",
442
- "print('🚀 Starting GRPO training...')\n",
443
- "trainer.train()\n",
444
- "print('✅ Training complete')"
445
- ]
446
- },
447
- {
448
- "cell_type": "markdown",
449
- "metadata": {},
450
- "source": [
451
- "## Cell 10 — Plot Reward Curve"
452
- ]
453
- },
454
- {
455
- "cell_type": "code",
456
- "execution_count": null,
457
- "metadata": {},
458
- "outputs": [],
459
- "source": [
460
- "import matplotlib.pyplot as plt\n",
461
- "import numpy as np\n",
462
- "\n",
463
- "# Smooth with moving average\n",
464
- "def moving_avg(data, window=5):\n",
465
- " if len(data) < window:\n",
466
- " return data\n",
467
- " return np.convolve(data, np.ones(window)/window, mode='valid')\n",
468
- "\n",
469
- "fig, ax = plt.subplots(figsize=(10, 5))\n",
470
- "ax.plot(step_log, reward_log, alpha=0.3, color='steelblue', label='Raw reward')\n",
471
- "smoothed = moving_avg(reward_log)\n",
472
- "ax.plot(range(len(smoothed)), smoothed, color='steelblue', linewidth=2, label='Smoothed (MA-5)')\n",
473
- "ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
474
- "ax.set_xlabel('Training Step')\n",
475
- "ax.set_ylabel('Avg Reward per Batch')\n",
476
- "ax.set_title('MetaGuard GRPO — Reward Curve')\n",
477
- "ax.legend()\n",
478
- "ax.grid(alpha=0.3)\n",
479
- "\n",
480
- "plt.tight_layout()\n",
481
- "plt.savefig('outputs/reward_plot.png', dpi=150)\n",
482
- "plt.show()\n",
483
- "print('✅ Plot saved to outputs/reward_plot.png')\n",
484
- "\n",
485
- "# Print before/after summary\n",
486
- "n = len(reward_log)\n",
487
- "first_10 = reward_log[:min(10, n)]\n",
488
- "last_10 = reward_log[max(0, n-10):]\n",
489
- "print(f'\\n--- Results ---')\n",
490
- "print(f'Avg reward (first 10 steps): {sum(first_10)/len(first_10):.3f}')\n",
491
- "print(f'Avg reward (last 10 steps) : {sum(last_10)/len(last_10):.3f}')"
492
- ]
493
- },
494
- {
495
- "cell_type": "markdown",
496
- "metadata": {},
497
- "source": [
498
- "## Cell 11 — Save + Push to HF Hub"
499
- ]
500
- },
501
- {
502
- "cell_type": "code",
503
- "execution_count": null,
504
- "metadata": {},
505
- "outputs": [],
506
- "source": [
507
- "model.save_pretrained('outputs/lora_adapter')\n",
508
- "tokenizer.save_pretrained('outputs/lora_adapter')\n",
509
- "print('✅ LoRA adapter saved')\n",
510
- "\n",
511
- "print('Merging adapter into base model (bf16)...')\n",
512
- "merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n",
513
- " model_name='outputs/lora_adapter',\n",
514
- " load_in_4bit=False,\n",
515
- " max_seq_length=2048,\n",
516
- ")\n",
517
- "merged_model.save_pretrained_merged(\n",
518
- " 'outputs/merged',\n",
519
- " merged_tokenizer,\n",
520
- " save_method='merged_16bit',\n",
521
- ")\n",
522
- "print('✅ Merged model saved')\n",
523
- "\n",
524
- "if HF_REPO and HF_TOKEN:\n",
525
- " print(f'Pushing to {HF_REPO}...')\n",
526
- " merged_model.push_to_hub_merged(\n",
527
- " HF_REPO,\n",
528
- " merged_tokenizer,\n",
529
- " save_method='merged_16bit',\n",
530
- " token=HF_TOKEN,\n",
531
- " )\n",
532
- " print(f'✅ Model live at https://huggingface.co/{HF_REPO}')\n",
533
- "else:\n",
534
- " print('⚠️ Set HF_REPO and HF_TOKEN to push to Hub')\n",
535
- "\n",
536
- "print('Done.')"
537
- ]
538
- }
539
- ]
540
- }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🛡️ MetaGuard — GRPO Training Notebook\n",
8
+ "\n",
9
+ "**Team:** Parth Singhal, Mehakveer Kaur, Kartik Goyal \n",
10
+ "**HF Space:** https://huggingface.co/spaces/parth-1/MetaGuard \n",
11
+ "**Hackathon:** OpenEnv — Meta × Scaler \n",
12
+ "\n",
13
+ "This notebook trains **Llama 3.1 8B** using GRPO on the MetaGuard Ad Policy Compliance environment.\n",
14
+ "\n",
15
+ "### What this trains:\n",
16
+ "- Agent learns to follow structured SOP: `query_regulations → gather signals → submit_audit → decide`\n",
17
+ "- Reward shaped by correctness, sequence compliance, API failure recovery\n",
18
+ "- Environment runs locally in the notebook (fast); GPU handles only the model"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## Cell 1 — Install Dependencies"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "metadata": {},
31
+ "source": [
32
+ "!pip install unsloth trl transformers datasets accelerate peft -q\n",
33
+ "!pip install openenv-core==0.2.1 --no-deps -q\n",
34
+ "!pip install fastapi uvicorn pydantic requests openai matplotlib -q\n",
35
+ "print('✅ Dependencies installed')"
36
+ ],
37
+ "execution_count": null,
38
+ "outputs": []
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "## Cell 2 — Clone Repo"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "metadata": {},
50
+ "source": [
51
+ "import os\n",
52
+ "\n",
53
+ "if not os.path.exists('meta-ad-policy-sandbox'):\n",
54
+ " !git clone https://github.com/Parth380/meta-ad-policy-sandbox.git\n",
55
+ "\n",
56
+ "%cd meta-ad-policy-sandbox\n",
57
+ "!pip install -e . -q\n",
58
+ "os.makedirs('outputs', exist_ok=True)\n",
59
+ "print('Repo installed & outputs/ ready')"
60
+ ],
61
+ "execution_count": null,
62
+ "outputs": []
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "metadata": {},
67
+ "source": [
68
+ "## Cell 3 — Config (SET THESE)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "metadata": {},
74
+ "source": [
75
+ "import os\n",
76
+ "\n",
77
+ "os.environ['ENV_URL'] = 'http://localhost:8000' # local env (fast); change to HF Space URL if needed\n",
78
+ "os.environ['HF_REPO'] = 'parth-1/metaguard-llama3.1-8b-grpo'\n",
79
+ "os.environ['HF_TOKEN'] = '' # paste your HF write token here\n",
80
+ "\n",
81
+ "ENV_URL = os.environ['ENV_URL']\n",
82
+ "HF_TOKEN = os.environ['HF_TOKEN']\n",
83
+ "HF_REPO = os.environ['HF_REPO']\n",
84
+ "\n",
85
+ "print(f'ENV_URL : {ENV_URL}')\n",
86
+ "print(f'HF_REPO : {HF_REPO}')\n",
87
+ "print(f'HF_TOKEN : {\"set\" if HF_TOKEN else \"MISSING -- set above before Cell 11\"}')"
88
+ ],
89
+ "execution_count": null,
90
+ "outputs": []
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {},
95
+ "source": [
96
+ "## Cell 4 — Boot Local Environment"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "metadata": {},
102
+ "source": [
103
+ "import subprocess, time, threading, requests\n",
104
+ "import uvicorn\n",
105
+ "\n",
106
+ "procs = [\n",
107
+ " subprocess.Popen(['python', 'apps/regulatory_api.py']),\n",
108
+ " subprocess.Popen(['python', 'apps/crm_api.py']),\n",
109
+ " subprocess.Popen(['python', 'apps/audit_api.py']),\n",
110
+ "]\n",
111
+ "time.sleep(3)\n",
112
+ "\n",
113
+ "from server.app import app as _env_app\n",
114
+ "threading.Thread(\n",
115
+ " target=uvicorn.run,\n",
116
+ " kwargs={'app': _env_app, 'host': '0.0.0.0', 'port': 8000, 'log_level': 'warning'},\n",
117
+ " daemon=True,\n",
118
+ ").start()\n",
119
+ "time.sleep(2)\n",
120
+ "\n",
121
+ "for i in range(20):\n",
122
+ " try:\n",
123
+ " r = requests.post(f'{ENV_URL}/reset', json={'task_id': 'task_1_healthcare'}, timeout=5)\n",
124
+ " if r.status_code == 200:\n",
125
+ " print(f'Environment ready (attempt {i+1})')\n",
126
+ " break\n",
127
+ " except:\n",
128
+ " pass\n",
129
+ " time.sleep(1)\n",
130
+ "else:\n",
131
+ " raise RuntimeError('ENV not reachable after 20 attempts')"
132
+ ],
133
+ "execution_count": null,
134
+ "outputs": []
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "## Cell 5 — Imports + Helpers"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "metadata": {},
146
+ "source": [
147
+ "import json\n",
148
+ "import random\n",
149
+ "import torch\n",
150
+ "import matplotlib.pyplot as plt\n",
151
+ "from collections import defaultdict\n",
152
+ "\n",
153
+ "from datasets import Dataset\n",
154
+ "from unsloth import FastLanguageModel, PatchFastRL\n",
155
+ "from trl import GRPOTrainer, GRPOConfig\n",
156
+ "\n",
157
+ "PatchFastRL('GRPO', FastLanguageModel)\n",
158
+ "\n",
159
+ "ALLOWED_ACTIONS = [\n",
160
+ " 'query_regulations', 'analyze_image', 'check_advertiser_history',\n",
161
+ " 'request_landing_page', 'request_id_verification',\n",
162
+ " 'submit_audit', 'approve', 'reject',\n",
163
+ "]\n",
164
+ "\n",
165
+ "class EnvClient:\n",
166
+ " def __init__(self, url):\n",
167
+ " self.url = url\n",
168
+ " def reset(self, task_id):\n",
169
+ " return requests.post(f'{self.url}/reset', json={'task_id': task_id}, timeout=8).json()\n",
170
+ " def step(self, action):\n",
171
+ " return requests.post(f'{self.url}/step', json={'action': action}, timeout=8).json()\n",
172
+ "\n",
173
+ "def safe_step(client, action):\n",
174
+ " for _ in range(3):\n",
175
+ " try:\n",
176
+ " return client.step(action)\n",
177
+ " except:\n",
178
+ " time.sleep(0.5)\n",
179
+ " return {'reward': -0.3}\n",
180
+ "\n",
181
+ "def extract_json(text):\n",
182
+ " try:\n",
183
+ " if '```' in text:\n",
184
+ " text = text.split('```')[1]\n",
185
+ " if text.startswith('json'):\n",
186
+ " text = text[4:]\n",
187
+ " return json.loads(text.strip())\n",
188
+ " except:\n",
189
+ " return None\n",
190
+ "\n",
191
+ "print('✅ Helpers loaded')"
192
+ ],
193
+ "execution_count": null,
194
+ "outputs": []
195
+ },
196
+ {
197
+ "cell_type": "markdown",
198
+ "metadata": {},
199
+ "source": [
200
+ "## Cell 6 — Dataset"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "metadata": {},
206
+ "source": [
207
+ "BASE_SCENARIOS = [\n",
208
+ " {'task_id': 'task_1_healthcare',\n",
209
+ " 'text': \"Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.\",\n",
210
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
211
+ " {'task_id': 'task_2_financial',\n",
212
+ " 'text': \"Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.\",\n",
213
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
214
+ " {'task_id': 'task_3_multimodal',\n",
215
+ " 'text': 'Multimodal ad: image may contain hidden violation. No actions taken yet.',\n",
216
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
217
+ " {'task_id': 'task_6_conflict',\n",
218
+ " 'text': 'High-trust advertiser but policy borderline. No actions taken yet.',\n",
219
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
220
+ " {'task_id': 'task_7_ambiguous',\n",
221
+ " 'text': 'Ambiguous wellness ad. Policy confidence low. No actions taken yet.',\n",
222
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
223
+ " {'task_id': 'task_8_adversarial',\n",
224
+ " 'text': 'Natural supplement ad. Image may hide violation. No actions taken yet.',\n",
225
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
226
+ " {'task_id': 'task_9_dependency_trap',\n",
227
+ " 'text': 'Certified wellness product. Regulations and CRM look clean. No actions taken yet.',\n",
228
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
229
+ " {'task_id': 'task_10_failure',\n",
230
+ " 'text': 'Healthcare ad. First API call may fail. No actions taken yet.',\n",
231
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
232
+ " # task_4 targeting — fresh\n",
233
+ " {'task_id': 'task_4_targeting',\n",
234
+ " 'text': \"Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.\",\n",
235
+ " 'actions_already_taken': [], 'setup_actions': []},\n",
236
+ " # task_4 targeting — mid state\n",
237
+ " {'task_id': 'task_4_targeting',\n",
238
+ " 'text': 'Financial ad targeting young users. Policy queried, need to verify age targeting.',\n",
239
+ " 'actions_already_taken': ['query_regulations'],\n",
240
+ " 'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
241
+ " # task_4 targeting — audit ready\n",
242
+ " {'task_id': 'task_4_targeting',\n",
243
+ " 'text': 'Financial ad targeting minors. Policy, history, and ID verification done. Submit audit.',\n",
244
+ " 'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'request_id_verification'],\n",
245
+ " 'setup_actions': [\n",
246
+ " {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
247
+ " {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
248
+ " {'action_type': 'request_id_verification', 'reasoning': 'age check'},\n",
249
+ " ]},\n",
250
+ " {'task_id': 'task_1_healthcare',\n",
251
+ " 'text': 'Healthcare ad. Policy already queried.',\n",
252
+ " 'actions_already_taken': ['query_regulations'],\n",
253
+ " 'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
254
+ " {'task_id': 'task_2_financial',\n",
255
+ " 'text': 'Financial ad. Policy and history checked. Submit audit next.',\n",
256
+ " 'actions_already_taken': ['query_regulations', 'check_advertiser_history'],\n",
257
+ " 'setup_actions': [\n",
258
+ " {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
259
+ " {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
260
+ " ]},\n",
261
+ " {'task_id': 'task_2_financial',\n",
262
+ " 'text': 'Financial ad. Policy, history, audit all complete. Make final decision.',\n",
263
+ " 'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'submit_audit'],\n",
264
+ " 'setup_actions': [\n",
265
+ " {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
266
+ " {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
267
+ " {'action_type': 'submit_audit', 'reasoning': 'audit log'},\n",
268
+ " ]},\n",
269
+ "]\n",
270
+ "\n",
271
+ "PROMPT_TEMPLATE = \"\"\"You are an enterprise Ad Policy Compliance Agent.\n",
272
+ "\n",
273
+ "You MUST choose exactly ONE action_type from this list (any other value is invalid):\n",
274
+ "- query_regulations\n",
275
+ "- analyze_image\n",
276
+ "- check_advertiser_history\n",
277
+ "- request_landing_page\n",
278
+ "- request_id_verification\n",
279
+ "- submit_audit\n",
280
+ "- approve\n",
281
+ "- reject\n",
282
+ "\n",
283
+ "REQUIRED PHASE ORDER:\n",
284
+ "1. query_regulations -> always first\n",
285
+ "2. analyze_image / check_advertiser_history -> gather signals\n",
286
+ "3. submit_audit -> always before final decision\n",
287
+ "4. approve OR reject -> only after audit\n",
288
+ "\n",
289
+ "HARD RULES:\n",
290
+ "- NEVER repeat an action listed in `actions_already_taken`.\n",
291
+ "- Respond with ONLY a valid JSON object. No markdown, no prose.\n",
292
+ "\n",
293
+ "Required format:\n",
294
+ "{{\\\"action_type\\\": \\\"<one_of_the_actions_above>\\\", \\\"reasoning\\\": \\\"<short reason>\\\"}}\n",
295
+ "\n",
296
+ "Scenario: {text}\n",
297
+ "actions_already_taken: {actions_already_taken}\n",
298
+ "\n",
299
+ "Your next action?\"\"\"\n",
300
+ "\n",
301
+ "def build_dataset():\n",
302
+ " rows = []\n",
303
+ " for s in BASE_SCENARIOS:\n",
304
+ " prompt = PROMPT_TEMPLATE.format(\n",
305
+ " text=s['text'],\n",
306
+ " actions_already_taken=json.dumps(s['actions_already_taken']),\n",
307
+ " )\n",
308
+ " rows.append({'prompt': prompt, 'task_id': s['task_id'], 'setup_actions': s['setup_actions']})\n",
309
+ " return Dataset.from_list(rows * 10)\n",
310
+ "\n",
311
+ "dataset = build_dataset()\n",
312
+ "print(f'✅ Dataset: {len(dataset)} examples')"
313
+ ],
314
+ "execution_count": null,
315
+ "outputs": []
316
+ },
317
+ {
318
+ "cell_type": "markdown",
319
+ "metadata": {},
320
+ "source": [
321
+ "## Cell 7 — Reward Function"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "metadata": {},
327
+ "source": [
328
+ "# Track rewards for plotting\n",
329
+ "reward_log = []\n",
330
+ "step_log = []\n",
331
+ "global_step_counter = [0]\n",
332
+ "\n",
333
+ "def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):\n",
334
+ " client = EnvClient(ENV_URL)\n",
335
+ " rewards = []\n",
336
+ "\n",
337
+ " for completion, t_id, setup in zip(completions, task_id, setup_actions):\n",
338
+ " parsed = extract_json(completion)\n",
339
+ " if not parsed:\n",
340
+ " rewards.append(-1.0)\n",
341
+ " continue\n",
342
+ "\n",
343
+ " action_type = parsed.get('action_type')\n",
344
+ " if action_type not in ALLOWED_ACTIONS:\n",
345
+ " rewards.append(-1.0)\n",
346
+ " continue\n",
347
+ "\n",
348
+ " action = {\n",
349
+ " 'action_type': action_type,\n",
350
+ " 'reasoning': parsed.get('reasoning', 'format-compliant'),\n",
351
+ " }\n",
352
+ "\n",
353
+ " try:\n",
354
+ " client.reset(t_id)\n",
355
+ " for s in setup:\n",
356
+ " safe_step(client, s)\n",
357
+ "\n",
358
+ " result = safe_step(client, action)\n",
359
+ " env_reward = float(result.get('reward', -0.2))\n",
360
+ " status_msg = (result.get('status_message') or '').lower()\n",
361
+ "\n",
362
+ " rejected = (\n",
363
+ " 'api failure' in status_msg\n",
364
+ " or 'invalid action' in status_msg\n",
365
+ " or 'must call' in status_msg\n",
366
+ " )\n",
367
+ " shaped = -0.5 if rejected else 0.5 + env_reward\n",
368
+ " rewards.append(shaped)\n",
369
+ "\n",
370
+ " except Exception:\n",
371
+ " rewards.append(-0.3)\n",
372
+ "\n",
373
+ " # Log for plot\n",
374
+ " avg = sum(rewards) / len(rewards) if rewards else 0.0\n",
375
+ " global_step_counter[0] += 1\n",
376
+ " reward_log.append(avg)\n",
377
+ " step_log.append(global_step_counter[0])\n",
378
+ "\n",
379
+ " return rewards\n",
380
+ "\n",
381
+ "print('✅ Reward function ready')"
382
+ ],
383
+ "execution_count": null,
384
+ "outputs": []
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "metadata": {},
389
+ "source": [
390
+ "## Cell 8 — Load Model"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "metadata": {},
396
+ "source": [
397
+ "if torch.cuda.is_available():\n",
398
+ " _props = torch.cuda.get_device_properties(0)\n",
399
+ " _vram = _props.total_memory\n",
400
+ " _cc = (_props.major, _props.minor)\n",
401
+ " print(f'GPU: {_props.name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}')\n",
402
+ "else:\n",
403
+ " _vram, _cc = 0, (0, 0)\n",
404
+ "\n",
405
+ "USE_4BIT = _vram < 40 * 1024**3 # T4/L4 → 4-bit; A100 → full precision\n",
406
+ "USE_BF16 = _cc >= (8, 0) and not USE_4BIT # bf16 only with full-precision weights; 4-bit LoRA uses fp16\n",
407
+ "print(f'4-bit: {USE_4BIT} bf16: {USE_BF16}')\n",
408
+ "\n",
409
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
410
+ " model_name='unsloth/Llama-3.1-8B-Instruct',\n",
411
+ " load_in_4bit=USE_4BIT,\n",
412
+ " max_seq_length=2048,\n",
413
+ " dtype=torch.float16 if USE_4BIT else None,\n",
414
+ ")\n",
415
+ "\n",
416
+ "model = FastLanguageModel.get_peft_model(\n",
417
+ " model,\n",
418
+ " r=32,\n",
419
+ " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
420
+ " lora_alpha=64,\n",
421
+ " lora_dropout=0,\n",
422
+ " bias='none',\n",
423
+ " use_gradient_checkpointing='unsloth',\n",
424
+ " random_state=3407,\n",
425
+ ")\n",
426
+ "print('✅ Model loaded')"
427
+ ],
428
+ "execution_count": null,
429
+ "outputs": []
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {},
434
+ "source": [
435
+ "## Cell 9 — Train"
436
+ ]
437
+ },
438
+ {
439
+ "cell_type": "code",
440
+ "metadata": {},
441
+ "source": [
442
+ "trainer = GRPOTrainer(\n",
443
+ " model=model,\n",
444
+ " reward_funcs=[reward_environment],\n",
445
+ " args=GRPOConfig(\n",
446
+ " output_dir='outputs',\n",
447
+ " learning_rate=2e-5,\n",
448
+ " num_train_epochs=3,\n",
449
+ " per_device_train_batch_size=2 if not USE_4BIT else 1,\n",
450
+ " gradient_accumulation_steps=4,\n",
451
+ " num_generations=4 if not USE_4BIT else 2,\n",
452
+ " max_prompt_length=768,\n",
453
+ " max_completion_length=128,\n",
454
+ " logging_steps=5,\n",
455
+ " warmup_steps=10,\n",
456
+ " bf16=USE_BF16,\n",
457
+ " fp16=not USE_BF16,\n",
458
+ " report_to='none',\n",
459
+ " ),\n",
460
+ " train_dataset=dataset,\n",
461
+ " tokenizer=tokenizer,\n",
462
+ ")\n",
463
+ "\n",
464
+ "print('Starting GRPO training...')\n",
465
+ "print(f' bf16={USE_BF16} fp16={not USE_BF16} batch={2 if not USE_4BIT else 1} gens={4 if not USE_4BIT else 2}')\n",
466
+ "trainer.train()\n",
467
+ "print('Training complete')"
468
+ ],
469
+ "execution_count": null,
470
+ "outputs": []
471
+ },
472
+ {
473
+ "cell_type": "markdown",
474
+ "metadata": {},
475
+ "source": [
476
+ "## Cell 10 — Plot Reward Curve"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "metadata": {},
482
+ "source": [
483
+ "import matplotlib.pyplot as plt\n",
484
+ "import numpy as np\n",
485
+ "import pandas as pd\n",
486
+ "import os\n",
487
+ "\n",
488
+ "os.makedirs('outputs', exist_ok=True)\n",
489
+ "\n",
490
+ "def moving_avg(data, window=5):\n",
491
+ " if len(data) < window:\n",
492
+ " return data\n",
493
+ " return list(np.convolve(data, np.ones(window)/window, mode='valid'))\n",
494
+ "\n",
495
+ "hist = pd.DataFrame(trainer.state.log_history)\n",
496
+ "\n",
497
+ "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
498
+ "\n",
499
+ "# --- Plot 1: Reward curve (from our custom log) ---\n",
500
+ "ax = axes[0]\n",
501
+ "ax.plot(step_log, reward_log, alpha=0.3, color='steelblue', label='Raw')\n",
502
+ "smoothed = moving_avg(reward_log)\n",
503
+ "ax.plot(range(len(smoothed)), smoothed, color='steelblue', linewidth=2, label='Smoothed (MA-5)')\n",
504
+ "ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
505
+ "ax.set_xlabel('Reward Eval Step')\n",
506
+ "ax.set_ylabel('Avg Reward per Batch')\n",
507
+ "ax.set_title('Reward Curve')\n",
508
+ "ax.legend()\n",
509
+ "ax.grid(alpha=0.3)\n",
510
+ "\n",
511
+ "# --- Plot 2: Loss curve (from trainer logs) ---\n",
512
+ "ax = axes[1]\n",
513
+ "loss_rows = hist.dropna(subset=['loss']) if 'loss' in hist.columns else pd.DataFrame()\n",
514
+ "if not loss_rows.empty:\n",
515
+ " ax.plot(loss_rows['step'], loss_rows['loss'], color='#7c3aed', linewidth=2)\n",
516
+ " ax.set_xlabel('Training Step')\n",
517
+ " ax.set_ylabel('Loss')\n",
518
+ " ax.set_title('GRPO Loss')\n",
519
+ " ax.grid(alpha=0.3)\n",
520
+ "else:\n",
521
+ " ax.text(0.5, 0.5, 'No loss data logged', ha='center', va='center', transform=ax.transAxes)\n",
522
+ " ax.set_title('GRPO Loss')\n",
523
+ "\n",
524
+ "# --- Plot 3: Reward from trainer logs (if available) ---\n",
525
+ "ax = axes[2]\n",
526
+ "reward_cols = [c for c in hist.columns if 'reward' in c.lower() and 'std' not in c.lower()]\n",
527
+ "if reward_cols:\n",
528
+ " col = reward_cols[0]\n",
529
+ " rr = hist.dropna(subset=[col])\n",
530
+ " ax.plot(rr['step'], rr[col], color='#16a34a', linewidth=2)\n",
531
+ " ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
532
+ " ax.set_xlabel('Training Step')\n",
533
+ " ax.set_ylabel(col)\n",
534
+ " ax.set_title('Trainer Reward Log')\n",
535
+ " ax.grid(alpha=0.3)\n",
536
+ "else:\n",
537
+ " ax.text(0.5, 0.5, 'No trainer reward data', ha='center', va='center', transform=ax.transAxes)\n",
538
+ " ax.set_title('Trainer Reward Log')\n",
539
+ "\n",
540
+ "plt.tight_layout()\n",
541
+ "plt.savefig('outputs/training_plots.png', dpi=150)\n",
542
+ "plt.show()\n",
543
+ "print('Saved to outputs/training_plots.png')\n",
544
+ "\n",
545
+ "n = len(reward_log)\n",
546
+ "first_10 = reward_log[:min(10, n)]\n",
547
+ "last_10 = reward_log[max(0, n-10):]\n",
548
+ "print(f'\\n--- Before vs After ---')\n",
549
+ "print(f'Avg reward (first 10 steps): {sum(first_10)/len(first_10):.3f}')\n",
550
+ "print(f'Avg reward (last 10 steps) : {sum(last_10)/len(last_10):.3f}')"
551
+ ],
552
+ "execution_count": null,
553
+ "outputs": []
554
+ },
555
+ {
556
+ "cell_type": "markdown",
557
+ "metadata": {},
558
+ "source": [
559
+ "## Cell 11 — Before vs After: Baseline Comparison"
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "metadata": {},
565
+ "source": [
566
+ "from unsloth import FastLanguageModel as FLM\n",
567
+ "\n",
568
+ "FLM.for_inference(model)\n",
569
+ "\n",
570
+ "test_scenarios = [\n",
571
+ " ('task_1_healthcare', \"Healthcare ad: 'miracle cure'. No actions taken yet.\", []),\n",
572
+ " ('task_2_financial', \"Financial ad: 'guaranteed returns'. No actions taken yet.\", []),\n",
573
+ " ('task_4_targeting', \"Financial ad targeting teens. No actions taken yet.\", []),\n",
574
+ " ('task_2_financial', \"Financial ad. Policy, history, audit done. Decide.\",\n",
575
+ " ['query_regulations', 'check_advertiser_history', 'submit_audit']),\n",
576
+ "]\n",
577
+ "\n",
578
+ "print('=== Trained Model Outputs ===\\n')\n",
579
+ "for task, text, taken in test_scenarios:\n",
580
+ " prompt = PROMPT_TEMPLATE.format(text=text, actions_already_taken=json.dumps(taken))\n",
581
+ " inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n",
582
+ " out = model.generate(**inputs, max_new_tokens=64, temperature=0.1, do_sample=True)\n",
583
+ " decoded = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
584
+ " parsed = extract_json(decoded) or decoded.strip()[:120]\n",
585
+ " print(f'[{task}] taken={taken}')\n",
586
+ " print(f' -> {parsed}\\n')"
587
+ ],
588
+ "execution_count": null,
589
+ "outputs": []
590
+ },
591
+ {
592
+ "cell_type": "markdown",
593
+ "metadata": {},
594
+ "source": [
595
+ "## Cell 12 — Save + Push to HF Hub"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "metadata": {},
601
+ "source": [
602
+ "model.save_pretrained('outputs/lora_adapter')\n",
603
+ "tokenizer.save_pretrained('outputs/lora_adapter')\n",
604
+ "print('LoRA adapter saved')\n",
605
+ "\n",
606
+ "print('Merging adapter into base model...')\n",
607
+ "merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n",
608
+ " model_name='outputs/lora_adapter',\n",
609
+ " load_in_4bit=False,\n",
610
+ " max_seq_length=2048,\n",
611
+ ")\n",
612
+ "merged_model.save_pretrained_merged(\n",
613
+ " 'outputs/merged',\n",
614
+ " merged_tokenizer,\n",
615
+ " save_method='merged_16bit',\n",
616
+ ")\n",
617
+ "print('Merged model saved to outputs/merged')\n",
618
+ "\n",
619
+ "if HF_REPO and HF_TOKEN:\n",
620
+ " print(f'Pushing to {HF_REPO}...')\n",
621
+ " merged_model.push_to_hub_merged(\n",
622
+ " HF_REPO,\n",
623
+ " merged_tokenizer,\n",
624
+ " save_method='merged_16bit',\n",
625
+ " token=HF_TOKEN,\n",
626
+ " )\n",
627
+ " print(f'Model live at https://huggingface.co/{HF_REPO}')\n",
628
+ "else:\n",
629
+ " print('Set HF_REPO and HF_TOKEN in Cell 3 to push to Hub')\n",
630
+ "\n",
631
+ "print('Done.')"
632
+ ],
633
+ "execution_count": null,
634
+ "outputs": []
635
+ }
636
+ ],
637
+ "metadata": {
638
+ "colab": {
639
+ "provenance": [],
640
+ "gpuType": "A100"
641
+ },
642
+ "kernelspec": {
643
+ "display_name": "Python 3",
644
+ "name": "python3"
645
+ },
646
+ "language_info": {
647
+ "name": "python"
648
+ },
649
+ "accelerator": "GPU"
650
  },
651
+ "nbformat": 4,
652
+ "nbformat_minor": 0
653
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
grpo_train.py CHANGED
@@ -13,6 +13,16 @@ from trl import GRPOTrainer, GRPOConfig
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
 
 
 
 
 
 
 
 
 
 
16
  # =========================
17
  # CONFIG
18
  # =========================
@@ -37,7 +47,10 @@ ALLOWED_ACTIONS = [
37
  # =========================
38
 
39
  def ensure_env_ready():
40
- for _ in range(20):
 
 
 
41
  try:
42
  r = requests.post(
43
  f"{ENV_URL}/reset",
@@ -45,11 +58,20 @@ def ensure_env_ready():
45
  timeout=5
46
  )
47
  if r.status_code == 200:
 
 
 
48
  print("✅ Environment ready")
49
  return
50
- except:
 
 
 
51
  pass
52
  time.sleep(1)
 
 
 
53
  raise RuntimeError("❌ ENV not reachable")
54
 
55
  # =========================
@@ -240,21 +262,39 @@ def build_dataset():
240
  # REWARD FUNCTION (FIXED)
241
  # =========================
242
 
 
 
243
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
244
- """Shaped reward for GRPO.
245
-
246
- Pure env reward is too sparse (mostly -0.05) to give clear gradients.
247
- We add explicit shaping:
248
- - invalid JSON / invalid action_type -> -1.0 (strong negative signal)
249
- - valid action env REJECTS (wrong phase / API failure) -> -0.5
250
- - valid action env ACCEPTS (advances state) -> +0.5 + env_reward
251
- - terminal correct decision -> env_reward already contains +1.0 bonus
252
- """
 
 
 
 
 
 
 
253
  client = EnvClient(ENV_URL)
254
  rewards = []
255
 
256
- for completion, t_id, setup in zip(completions, task_id, setup_actions):
 
 
 
 
 
 
257
  parsed = extract_json(completion)
 
 
 
258
  if not parsed:
259
  rewards.append(-1.0)
260
  continue
@@ -300,23 +340,39 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
300
  # MODEL
301
  # =========================
302
 
303
- USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  model, tokenizer = FastLanguageModel.from_pretrained(
306
  model_name="unsloth/Llama-3.1-8B-Instruct",
307
  load_in_4bit=USE_4BIT,
308
  max_seq_length=2048,
309
- dtype=None, # auto-detect bf16 on A100
310
  )
311
 
312
  model = FastLanguageModel.get_peft_model(
313
  model,
314
- r=32,
315
  target_modules=[
316
  "q_proj", "k_proj", "v_proj", "o_proj",
317
  "gate_proj", "up_proj", "down_proj",
318
  ],
319
- lora_alpha=64,
320
  lora_dropout=0,
321
  bias="none",
322
  use_gradient_checkpointing="unsloth",
@@ -329,21 +385,26 @@ model = FastLanguageModel.get_peft_model(
329
 
330
  dataset = build_dataset()
331
 
 
 
 
 
332
  trainer = GRPOTrainer(
333
  model=model,
334
  reward_funcs=[reward_environment],
335
  args=GRPOConfig(
336
  output_dir="outputs",
337
  learning_rate=2e-5,
338
- num_train_epochs=3,
339
- per_device_train_batch_size=2,
340
- gradient_accumulation_steps=4,
341
- num_generations=4,
342
  max_prompt_length=768,
343
  max_completion_length=128,
344
- logging_steps=5,
345
- warmup_ratio=0.1,
346
- bf16=True,
 
347
  report_to="none",
348
  ),
349
  train_dataset=dataset,
@@ -357,6 +418,9 @@ trainer = GRPOTrainer(
357
  if __name__ == "__main__":
358
  ensure_env_ready()
359
 
 
 
 
360
  print("Starting GRPO training...")
361
  trainer.train()
362
 
 
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
16
+ # #region agent log
17
+ import pathlib as _pl
18
+ _DLOG = _pl.Path("debug-851b5f.log")
19
+ def _dlog(hyp, loc, msg, data=None):
20
+ import time as _t
21
+ entry = json.dumps({"sessionId":"851b5f","hypothesisId":hyp,"location":loc,"message":msg,"data":data or {},"timestamp":int(_t.time()*1000)})
22
+ with open(_DLOG, "a") as f: f.write(entry + "\n")
23
+ print(f"[DBG:{hyp}] {msg} {data or ''}", flush=True)
24
+ # #endregion
25
+
26
  # =========================
27
  # CONFIG
28
  # =========================
 
47
  # =========================
48
 
49
  def ensure_env_ready():
50
+ # #region agent log
51
+ _dlog("B", "grpo_train.py:ensure_env_ready", "Checking env", {"ENV_URL": ENV_URL})
52
+ # #endregion
53
+ for i in range(20):
54
  try:
55
  r = requests.post(
56
  f"{ENV_URL}/reset",
 
58
  timeout=5
59
  )
60
  if r.status_code == 200:
61
+ # #region agent log
62
+ _dlog("B", "grpo_train.py:ensure_env_ready", "Env ready", {"attempt": i+1, "status": r.status_code})
63
+ # #endregion
64
  print("✅ Environment ready")
65
  return
66
+ except Exception as e:
67
+ # #region agent log
68
+ if i == 0: _dlog("B", "grpo_train.py:ensure_env_ready", "Env connection failed", {"attempt": i+1, "error": str(e)[:200]})
69
+ # #endregion
70
  pass
71
  time.sleep(1)
72
+ # #region agent log
73
+ _dlog("B", "grpo_train.py:ensure_env_ready", "ENV UNREACHABLE after 20 attempts", {})
74
+ # #endregion
75
  raise RuntimeError("❌ ENV not reachable")
76
 
77
  # =========================
 
262
  # REWARD FUNCTION (FIXED)
263
  # =========================
264
 
265
+ _reward_call_count = [0]
266
+
267
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
268
+ """Shaped reward for GRPO."""
269
+ _reward_call_count[0] += 1
270
+ _call = _reward_call_count[0]
271
+ # #region agent log
272
+ _dlog("C", "grpo_train.py:reward_env", f"reward call #{_call}", {
273
+ "n_prompts": len(prompts) if prompts else 0,
274
+ "n_completions": len(completions) if completions else 0,
275
+ "completions_type": type(completions).__name__,
276
+ "first_completion_type": type(completions[0]).__name__ if completions else "N/A",
277
+ "first_completion_preview": str(completions[0])[:150] if completions else "N/A",
278
+ "task_id_is_none": task_id is None,
279
+ "setup_actions_is_none": setup_actions is None,
280
+ "kwargs_keys": list(kwargs.keys()),
281
+ })
282
+ # #endregion
283
+
284
  client = EnvClient(ENV_URL)
285
  rewards = []
286
 
287
+ if task_id is None or setup_actions is None:
288
+ # #region agent log
289
+ _dlog("D", "grpo_train.py:reward_env", "task_id or setup_actions is None — returning -1 for all", {"call": _call})
290
+ # #endregion
291
+ return [-1.0] * len(completions)
292
+
293
+ for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
294
  parsed = extract_json(completion)
295
+ # #region agent log
296
+ if _call <= 3: _dlog("D", "grpo_train.py:reward_loop", f"call#{_call} item#{idx}", {"parsed_ok": parsed is not None, "action": parsed.get("action_type") if parsed else None, "raw_preview": str(completion)[:120], "task_id": t_id})
297
+ # #endregion
298
  if not parsed:
299
  rewards.append(-1.0)
300
  continue
 
340
  # MODEL
341
  # =========================
342
 
343
+ if torch.cuda.is_available():
344
+ _props = torch.cuda.get_device_properties(0)
345
+ _vram = _props.total_memory
346
+ _name = _props.name
347
+ _cc = (_props.major, _props.minor) # compute capability
348
+ print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}")
349
+ else:
350
+ _vram = 0
351
+ _name = "CPU"
352
+ _cc = (0, 0)
353
+
354
+ USE_4BIT = _vram < 40 * 1024**3 # T4 (15 GB), L4 (24 GB) → 4-bit; A100 (80 GB) → full
355
+ USE_BF16 = _cc >= (8, 0) and not USE_4BIT # bf16 only when full-precision; 4-bit LoRA uses fp16 internally
356
+
357
+ # #region agent log
358
+ _dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})
359
+ # #endregion
360
 
361
  model, tokenizer = FastLanguageModel.from_pretrained(
362
  model_name="unsloth/Llama-3.1-8B-Instruct",
363
  load_in_4bit=USE_4BIT,
364
  max_seq_length=2048,
365
+ dtype=torch.float16 if USE_4BIT else None,
366
  )
367
 
368
  model = FastLanguageModel.get_peft_model(
369
  model,
370
+ r=16 if USE_4BIT else 32,
371
  target_modules=[
372
  "q_proj", "k_proj", "v_proj", "o_proj",
373
  "gate_proj", "up_proj", "down_proj",
374
  ],
375
+ lora_alpha=32 if USE_4BIT else 64,
376
  lora_dropout=0,
377
  bias="none",
378
  use_gradient_checkpointing="unsloth",
 
385
 
386
  dataset = build_dataset()
387
 
388
+ # #region agent log
389
+ _dlog("A", "grpo_train.py:trainer_init", "Creating GRPOTrainer", {"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16,"epochs":1 if USE_4BIT else 3,"batch":1 if USE_4BIT else 2,"gens":2 if USE_4BIT else 4,"dataset_len":len(dataset)})
390
+ # #endregion
391
+
392
  trainer = GRPOTrainer(
393
  model=model,
394
  reward_funcs=[reward_environment],
395
  args=GRPOConfig(
396
  output_dir="outputs",
397
  learning_rate=2e-5,
398
+ num_train_epochs=1 if USE_4BIT else 3,
399
+ per_device_train_batch_size=1 if USE_4BIT else 2,
400
+ gradient_accumulation_steps=2 if USE_4BIT else 4,
401
+ num_generations=2 if USE_4BIT else 4,
402
  max_prompt_length=768,
403
  max_completion_length=128,
404
+ logging_steps=3 if USE_4BIT else 5,
405
+ warmup_steps=5 if USE_4BIT else 10,
406
+ bf16=USE_BF16,
407
+ fp16=not USE_BF16,
408
  report_to="none",
409
  ),
410
  train_dataset=dataset,
 
418
  if __name__ == "__main__":
419
  ensure_env_ready()
420
 
421
+ # #region agent log
422
+ _dlog("E", "grpo_train.py:train_start", "About to call trainer.train()", {"gpu_mem_allocated_gb": round(torch.cuda.memory_allocated()/1024**3, 2) if torch.cuda.is_available() else 0})
423
+ # #endregion
424
  print("Starting GRPO training...")
425
  trainer.train()
426