parth-1 commited on
Commit
c0ec9d7
·
verified ·
1 Parent(s): bd74b58

Upload grpo_train (1).ipynb

Browse files
Files changed (1) hide show
  1. grpo_train (1).ipynb +777 -0
grpo_train (1).ipynb ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🛡️ MetaGuard — GRPO Training Notebook\n",
8
+ "\n",
9
+ "**Team:** Parth Singhal, Mehakveer Kaur, Kartik Goyal \n",
10
+ "**HF Space:** https://huggingface.co/spaces/parth-1/MetaGuard \n",
11
+ "**Hackathon:** OpenEnv — Meta × Scaler \n",
12
+ "\n",
13
+ "This notebook trains **Llama 3.1 8B** using GRPO on the MetaGuard Ad Policy Compliance environment.\n",
14
+ "\n",
15
+ "### What this trains:\n",
16
+ "- Agent learns to follow structured SOP: `query_regulations → gather signals → submit_audit → decide`\n",
17
+ "- Reward shaped by correctness, sequence compliance, API failure recovery\n",
18
+ "- Environment runs locally in the notebook (fast); GPU handles only the model"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## Cell 1 — Install Dependencies"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "metadata": {},
31
+ "source": [
32
+ "!pip install unsloth trl transformers datasets accelerate peft -q\n",
33
+ "!pip install openenv-core==0.2.1 --no-deps -q\n",
34
+ "!pip install fastapi uvicorn pydantic requests openai matplotlib -q\n",
35
+ "print('✅ Dependencies installed')"
36
+ ],
37
+ "execution_count": null,
38
+ "outputs": []
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "## Cell 2 — Clone Repo"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "metadata": {},
50
+ "source": [
51
+ "import os\n",
52
+ "\n",
53
+ "if not os.path.exists('meta-ad-policy-sandbox'):\n",
54
+ " !git clone https://github.com/Parth380/meta-ad-policy-sandbox.git\n",
55
+ "\n",
56
+ "%cd meta-ad-policy-sandbox\n",
57
+ "!pip install -e . -q\n",
58
+ "os.makedirs('outputs', exist_ok=True)\n",
59
+ "print('Repo installed & outputs/ ready')"
60
+ ],
61
+ "execution_count": null,
62
+ "outputs": []
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "metadata": {},
67
+ "source": [
68
+ "## Cell 3 — Config (SET THESE)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "metadata": {},
74
+ "source": [
75
+ "import os\n",
76
+ "\n",
77
+ "os.environ['ENV_URL'] = 'http://localhost:8000' # local env (fast); change to HF Space URL if needed\n",
78
+ "os.environ['HF_REPO'] = 'parth-1/metaguard-llama3.1-8b-grpo'\n",
79
+ "os.environ['HF_TOKEN'] = '' # paste your HF write token here\n",
80
+ "\n",
81
+ "ENV_URL = os.environ['ENV_URL']\n",
82
+ "HF_TOKEN = os.environ['HF_TOKEN']\n",
83
+ "HF_REPO = os.environ['HF_REPO']\n",
84
+ "\n",
85
+ "print(f'ENV_URL : {ENV_URL}')\n",
86
+ "print(f'HF_REPO : {HF_REPO}')\n",
87
+ "print(f'HF_TOKEN : {\"set\" if HF_TOKEN else \"MISSING -- set above before Cell 11\"}')"
88
+ ],
89
+ "execution_count": null,
90
+ "outputs": []
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {},
95
+ "source": [
96
+ "## Cell 4 — Boot Local Environment"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "metadata": {},
102
+ "source": [
103
+ "import os\n",
104
+ "os.environ.setdefault('USER', 'user')\n",
105
+ "\n",
106
+ "import subprocess, time, threading, requests\n",
107
+ "import uvicorn\n",
108
+ "\n",
109
+ "procs = [\n",
110
+ " subprocess.Popen(['python', 'apps/regulatory_api.py']),\n",
111
+ " subprocess.Popen(['python', 'apps/crm_api.py']),\n",
112
+ " subprocess.Popen(['python', 'apps/audit_api.py']),\n",
113
+ "]\n",
114
+ "time.sleep(3)\n",
115
+ "\n",
116
+ "from server.app import app as _env_app\n",
117
+ "threading.Thread(\n",
118
+ " target=uvicorn.run,\n",
119
+ " kwargs={'app': _env_app, 'host': '0.0.0.0', 'port': 8000, 'log_level': 'warning'},\n",
120
+ " daemon=True,\n",
121
+ ").start()\n",
122
+ "time.sleep(2)\n",
123
+ "\n",
124
+ "for i in range(20):\n",
125
+ " try:\n",
126
+ " r = requests.post(f'{ENV_URL}/reset', json={'task_id': 'task_1_healthcare'}, timeout=5)\n",
127
+ " if r.status_code == 200:\n",
128
+ " print(f'Environment ready (attempt {i+1})')\n",
129
+ " break\n",
130
+ " except:\n",
131
+ " pass\n",
132
+ " time.sleep(1)\n",
133
+ "else:\n",
134
+ " raise RuntimeError('ENV not reachable after 20 attempts')"
135
+ ],
136
+ "execution_count": null,
137
+ "outputs": []
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "## Cell 5 — Imports + Helpers"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "metadata": {},
149
+ "source": [
150
+ "import json\n",
151
+ "import random\n",
152
+ "import torch\n",
153
+ "import matplotlib.pyplot as plt\n",
154
+ "from collections import defaultdict\n",
155
+ "\n",
156
+ "from datasets import Dataset\n",
157
+ "from unsloth import FastLanguageModel, PatchFastRL\n",
158
+ "from trl import GRPOTrainer, GRPOConfig\n",
159
+ "\n",
160
+ "PatchFastRL('GRPO', FastLanguageModel)\n",
161
+ "\n",
162
+ "ALLOWED_ACTIONS = [\n",
163
+ " 'query_regulations', 'analyze_image', 'check_advertiser_history',\n",
164
+ " 'request_landing_page', 'request_id_verification',\n",
165
+ " 'submit_audit', 'approve', 'reject',\n",
166
+ "]\n",
167
+ "\n",
168
+ "class EnvClient:\n",
169
+ " def __init__(self, url):\n",
170
+ " self.url = url\n",
171
+ " def reset(self, task_id):\n",
172
+ " return requests.post(f'{self.url}/reset', json={'task_id': task_id}, timeout=8).json()\n",
173
+ " def step(self, action):\n",
174
+ " return requests.post(f'{self.url}/step', json={'action': action}, timeout=8).json()\n",
175
+ "\n",
176
+ "def safe_step(client, action):\n",
177
+ " for _ in range(3):\n",
178
+ " try:\n",
179
+ " return client.step(action)\n",
180
+ " except:\n",
181
+ " time.sleep(0.5)\n",
182
+ " return {'reward': -0.3}\n",
183
+ "\n",
184
+ "def extract_json(text):\n",
185
+ " try:\n",
186
+ " if '```' in text:\n",
187
+ " text = text.split('```')[1]\n",
188
+ " if text.startswith('json'):\n",
189
+ " text = text[4:]\n",
190
+ " return json.loads(text.strip())\n",
191
+ " except:\n",
192
+ " return None\n",
193
+ "\n",
194
+ "print('✅ Helpers loaded')"
195
+ ],
196
+ "execution_count": null,
197
+ "outputs": []
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "metadata": {},
202
+ "source": [
203
+ "## Cell 6 — Dataset"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "metadata": {},
209
+ "source": [
210
+ "SYSTEM_PROMPT = (\n",
211
+ " \"You are an enterprise Ad Policy Compliance Agent.\\n\"\n",
212
+ " \"You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.\\n\"\n",
213
+ " \"\\n\"\n",
214
+ " \"REQUIRED PHASE ORDER:\\n\"\n",
215
+ " \"1. query_regulations \\u2014 always first\\n\"\n",
216
+ " \"2. analyze_image \\u2014 required for visual/multimodal tasks\\n\"\n",
217
+ " \"3. check_advertiser_history or request_landing_page \\u2014 as needed\\n\"\n",
218
+ " \"4. submit_audit \\u2014 always before final decision\\n\"\n",
219
+ " \"5. approve or reject \\u2014 final decision only after audit\\n\"\n",
220
+ " \"\\n\"\n",
221
+ " \"AVAILABLE ACTIONS:\\n\"\n",
222
+ " \"- query_regulations\\n- analyze_image\\n- check_advertiser_history\\n\"\n",
223
+ " \"- request_landing_page\\n- request_id_verification\\n- submit_audit\\n\"\n",
224
+ " \"- approve\\n- reject\\n\"\n",
225
+ " \"\\n\"\n",
226
+ " \"HARD RULES:\\n\"\n",
227
+ " \"- NEVER repeat an action listed in `actions_already_taken`.\\n\"\n",
228
+ " \"- You MUST progress through the phase order. Do NOT call submit_audit or approve/reject\\n\"\n",
229
+ " \" before the prerequisite phases are complete.\\n\"\n",
230
+ " \"- Choose your action_type ONLY from the AVAILABLE ACTIONS list above. Any other value is invalid.\\n\"\n",
231
+ " \"\\n\"\n",
232
+ " 'Response format:\\n'\n",
233
+ " '{\"action_type\": \"<action>\", \"reasoning\": \"<brief reason>\"}'\n",
234
+ ")\n",
235
+ "\n",
236
+ "TASK_AD_DATA = {\n",
237
+ " 'task_1_healthcare': {'ad_id':'hc-001','headline':'Secret D1abet3s Cure','body_text':'Reverse it today with this root extract!','advertiser_trust_score':0.25,'targeting_data':{'age_min':25,'geo':'US'},'image_url':'https://mock-meta.com/img/hc-001.jpg'},\n",
238
+ " 'task_1_healthcare_b': {'ad_id':'hc-002','headline':'Optimize Metabolic Markers','body_text':'Our proprietary GLP-1 analog ships directly to your door without a visit.','advertiser_trust_score':0.35,'targeting_data':{'age_min':30,'geo':'US'},'image_url':'https://mock-meta.com/img/hc-002.jpg'},\n",
239
+ " 'task_2_financial': {'ad_id':'fn-001','headline':'Garanteed 500% Returns','body_text':'Invest in LlamaCoin now. Zero risk.','advertiser_trust_score':0.20,'targeting_data':{'age_min':21,'geo':'US'},'image_url':'https://mock-meta.com/img/fn-001.jpg'},\n",
240
+ " 'task_2_financial_b': {'ad_id':'fn-002','headline':'Beat the System','body_text':'Leverage our algorithmic trading bot to exploit market inefficiencies.','advertiser_trust_score':0.30,'targeting_data':{'age_min':25,'geo':'US'},'image_url':'https://mock-meta.com/img/fn-002.jpg'},\n",
241
+ " 'task_3_multimodal': {'ad_id':'mm-001','headline':'Achieve Your Dream Body','body_text':'Hard work and our new program will get you there.','advertiser_trust_score':0.90,'targeting_data':{'age_min':18,'geo':'US'},'image_url':'https://mock-meta.com/img/mm-001.jpg'},\n",
242
+ " 'task_3_multimodal_b': {'ad_id':'mm-002','headline':'Stay Focused All Day','body_text':'Our study supplement helps you crush your exams.','advertiser_trust_score':0.88,'targeting_data':{'age_min':18,'geo':'US'},'image_url':'https://mock-meta.com/img/mm-002.jpg'},\n",
243
+ " 'task_4_targeting': {'ad_id':'tg-001','headline':'Start Your First Investment Portfolio','body_text':'Learn the basics of stock trading with our beginner-friendly app.','advertiser_trust_score':0.72,'targeting_data':{'age_min':15,'geo':'US'},'image_url':'https://mock-meta.com/img/tg-001.jpg'},\n",
244
+ " 'task_6_conflict': {'ad_id':'cf-001','headline':'Smart Investment Tools','body_text':'Our platform helps you manage your portfolio with ease.','advertiser_trust_score':0.92,'targeting_data':{'age_min':21,'geo':'US'},'image_url':'https://mock-meta.com/img/cf-001.jpg'},\n",
245
+ " 'task_7_ambiguous': {'ad_id':'am-001','headline':'Feel Better Naturally','body_text':'Our wellness program supports your health journey.','advertiser_trust_score':0.55,'targeting_data':{'age_min':22,'geo':'US'},'image_url':'https://mock-meta.com/img/am-001.jpg'},\n",
246
+ " 'task_8_adversarial': {'ad_id':'av-001','headline':'Natural Energy Boost','body_text':'Feel amazing every day with our daily supplement.','advertiser_trust_score':0.40,'targeting_data':{'age_min':20,'geo':'US'},'image_url':'https://mock-meta.com/img/av-001.jpg'},\n",
247
+ " 'task_9_dependency_trap': {'ad_id':'dt-001','headline':'Wellness Support Pack','body_text':'Certified vitamins to support your daily routine.','advertiser_trust_score':0.85,'targeting_data':{'age_min':25,'geo':'US'},'image_url':'https://mock-meta.com/img/dt-001.jpg'},\n",
248
+ " 'task_10_failure': {'ad_id':'fr-001','headline':'Miracle Weight Loss Patch','body_text':'Lose 10kg in 2 weeks. Guaranteed results or money back.','advertiser_trust_score':0.15,'targeting_data':{'age_min':22,'geo':'US'},'image_url':'https://mock-meta.com/img/fr-001.jpg'},\n",
249
+ "}\n",
250
+ "\n",
251
+ "_sa = lambda *acts: [{'action_type': a, 'reasoning': 'setup'} for a in acts]\n",
252
+ "\n",
253
+ "BASE_SCENARIOS = [\n",
254
+ " # Task 1: Healthcare\n",
255
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_1_healthcare. Begin with query_regulations.','signals':{}},\n",
256
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare_b','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_1_healthcare. Begin with query_regulations.','signals':{}},\n",
257
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.92','signals':{'policy_confidence':0.92}},\n",
258
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare_b','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.78','signals':{'policy_confidence':0.78}},\n",
259
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':3,'actions_already_taken':['query_regulations','check_advertiser_history'],'setup_actions':_sa('query_regulations','check_advertiser_history'),'last_feedback':'risk_score=0.82','signals':{'policy_confidence':0.92,'risk_score':0.82}},\n",
260
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':4,'actions_already_taken':['query_regulations','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-001','signals':{'policy_confidence':0.92,'risk_score':0.82}},\n",
261
+ " # Task 2: Financial\n",
262
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_2_financial. Begin with query_regulations.','signals':{}},\n",
263
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial_b','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_2_financial. Begin with query_regulations.','signals':{}},\n",
264
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.88','signals':{'policy_confidence':0.88}},\n",
265
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':3,'actions_already_taken':['query_regulations','check_advertiser_history'],'setup_actions':_sa('query_regulations','check_advertiser_history'),'last_feedback':'risk_score=0.75','signals':{'policy_confidence':0.88,'risk_score':0.75}},\n",
266
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':4,'actions_already_taken':['query_regulations','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-002','signals':{'policy_confidence':0.88,'risk_score':0.75}},\n",
267
+ " # Task 3: Multimodal\n",
268
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_3_multimodal. Begin with query_regulations.','signals':{}},\n",
269
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal_b','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_3_multimodal. Begin with query_regulations.','signals':{}},\n",
270
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.65','signals':{'policy_confidence':0.65}},\n",
271
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':3,'actions_already_taken':['query_regulations','analyze_image'],'setup_actions':_sa('query_regulations','analyze_image'),'last_feedback':'image_violation_detected','signals':{'policy_confidence':0.65,'image_flag':True}},\n",
272
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':4,'actions_already_taken':['query_regulations','analyze_image','check_advertiser_history'],'setup_actions':_sa('query_regulations','analyze_image','check_advertiser_history'),'last_feedback':'risk_score=0.45','signals':{'policy_confidence':0.65,'image_flag':True,'risk_score':0.45}},\n",
273
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':5,'actions_already_taken':['query_regulations','analyze_image','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','analyze_image','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-003','signals':{'policy_confidence':0.65,'image_flag':True,'risk_score':0.45}},\n",
274
+ " # Task 4: Targeting\n",
275
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_4_targeting. Begin with query_regulations.','signals':{}},\n",
276
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.70','signals':{'policy_confidence':0.70}},\n",
277
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':3,'actions_already_taken':['query_regulations','request_id_verification'],'setup_actions':_sa('query_regulations','request_id_verification'),'last_feedback':'ALERT: minor targeting age=15','signals':{'policy_confidence':0.70}},\n",
278
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':4,'actions_already_taken':['query_regulations','request_id_verification','check_advertiser_history'],'setup_actions':_sa('query_regulations','request_id_verification','check_advertiser_history'),'last_feedback':'risk_score=0.60','signals':{'policy_confidence':0.70,'risk_score':0.60}},\n",
279
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':5,'actions_already_taken':['query_regulations','request_id_verification','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','request_id_verification','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-004','signals':{'policy_confidence':0.70,'risk_score':0.60}},\n",
280
+ " # Task 6: Conflict\n",
281
+ " {'task_id':'task_6_conflict','ad_key':'task_6_conflict','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_6_conflict. Begin with query_regulations.','signals':{}},\n",
282
+ " {'task_id':'task_6_conflict','ad_key':'task_6_conflict','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.72','signals':{'policy_confidence':0.72}},\n",
283
+ " {'task_id':'task_6_conflict','ad_key':'task_6_conflict','step_count':3,'actions_already_taken':['query_regulations','check_advertiser_history'],'setup_actions':_sa('query_regulations','check_advertiser_history'),'last_feedback':'risk_score=0.78','signals':{'policy_confidence':0.72,'risk_score':0.78}},\n",
284
+ " {'task_id':'task_6_conflict','ad_key':'task_6_conflict','step_count':4,'actions_already_taken':['query_regulations','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-006','signals':{'policy_confidence':0.72,'risk_score':0.78}},\n",
285
+ " # Task 7: Ambiguous\n",
286
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_7_ambiguous. Begin with query_regulations.','signals':{}},\n",
287
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.42','signals':{'policy_confidence':0.42}},\n",
288
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':3,'actions_already_taken':['query_regulations','check_advertiser_history'],'setup_actions':_sa('query_regulations','check_advertiser_history'),'last_feedback':'risk_score=0.55','signals':{'policy_confidence':0.42,'risk_score':0.55}},\n",
289
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':4,'actions_already_taken':['query_regulations','check_advertiser_history','request_landing_page'],'setup_actions':_sa('query_regulations','check_advertiser_history','request_landing_page'),'last_feedback':'landing_suspicious','signals':{'policy_confidence':0.42,'risk_score':0.55,'landing_flag':True}},\n",
290
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':5,'actions_already_taken':['query_regulations','check_advertiser_history','request_landing_page','submit_audit'],'setup_actions':_sa('query_regulations','check_advertiser_history','request_landing_page','submit_audit'),'last_feedback':'audit_logged id=AUD-007','signals':{'policy_confidence':0.42,'risk_score':0.55,'landing_flag':True}},\n",
291
+ " # Task 8: Adversarial\n",
292
+ " {'task_id':'task_8_adversarial','ad_key':'task_8_adversarial','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_8_adversarial. Begin with query_regulations.','signals':{}},\n",
293
+ " {'task_id':'task_8_adversarial','ad_key':'task_8_adversarial','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.75','signals':{'policy_confidence':0.75}},\n",
294
+ " {'task_id':'task_8_adversarial','ad_key':'task_8_adversarial','step_count':3,'actions_already_taken':['query_regulations','analyze_image'],'setup_actions':_sa('query_regulations','analyze_image'),'last_feedback':'image_violation_detected','signals':{'policy_confidence':0.75,'image_flag':True}},\n",
295
+ " {'task_id':'task_8_adversarial','ad_key':'task_8_adversarial','step_count':4,'actions_already_taken':['query_regulations','analyze_image','submit_audit'],'setup_actions':_sa('query_regulations','analyze_image','submit_audit'),'last_feedback':'audit_logged id=AUD-008','signals':{'policy_confidence':0.75,'image_flag':True}},\n",
296
+ " # Task 9: Dependency Trap\n",
297
+ " {'task_id':'task_9_dependency_trap','ad_key':'task_9_dependency_trap','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_9_dependency_trap. Begin with query_regulations.','signals':{}},\n",
298
+ " {'task_id':'task_9_dependency_trap','ad_key':'task_9_dependency_trap','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.50','signals':{'policy_confidence':0.50}},\n",
299
+ " {'task_id':'task_9_dependency_trap','ad_key':'task_9_dependency_trap','step_count':3,'actions_already_taken':['query_regulations','analyze_image'],'setup_actions':_sa('query_regulations','analyze_image'),'last_feedback':'image_violation_detected','signals':{'policy_confidence':0.50,'image_flag':True}},\n",
300
+ " {'task_id':'task_9_dependency_trap','ad_key':'task_9_dependency_trap','step_count':4,'actions_already_taken':['query_regulations','analyze_image','submit_audit'],'setup_actions':_sa('query_regulations','analyze_image','submit_audit'),'last_feedback':'audit_logged id=AUD-009','signals':{'policy_confidence':0.50,'image_flag':True}},\n",
301
+ " # Task 10: Failure Recovery\n",
302
+ " {'task_id':'task_10_failure','ad_key':'task_10_failure','step_count':1,'actions_already_taken':[],'setup_actions':[],'last_feedback':'Ad loaded for task_10_failure. Begin with query_regulations.','signals':{}},\n",
303
+ " {'task_id':'task_10_failure','ad_key':'task_10_failure','step_count':2,'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),'last_feedback':'policy_confidence=0.85','signals':{'policy_confidence':0.85}},\n",
304
+ " {'task_id':'task_10_failure','ad_key':'task_10_failure','step_count':3,'actions_already_taken':['query_regulations','check_advertiser_history'],'setup_actions':_sa('query_regulations','check_advertiser_history'),'last_feedback':'risk_score=0.80','signals':{'policy_confidence':0.85,'risk_score':0.80}},\n",
305
+ " {'task_id':'task_10_failure','ad_key':'task_10_failure','step_count':4,'actions_already_taken':['query_regulations','check_advertiser_history','submit_audit'],'setup_actions':_sa('query_regulations','check_advertiser_history','submit_audit'),'last_feedback':'audit_logged id=AUD-010','signals':{'policy_confidence':0.85,'risk_score':0.80}},\n",
306
+ "]\n",
307
+ "\n",
308
+ "def build_observation(scenario):\n",
309
+ " ad = TASK_AD_DATA[scenario['ad_key']]\n",
310
+ " sigs = scenario.get('signals', {})\n",
311
+ " return {\n",
312
+ " 'task_id': scenario['task_id'],\n",
313
+ " 'last_feedback': scenario['last_feedback'],\n",
314
+ " 'step_count': scenario['step_count'],\n",
315
+ " 'actions_already_taken': scenario['actions_already_taken'],\n",
316
+ " 'ad_details': {\n",
317
+ " **ad,\n",
318
+ " 'status_message': scenario['last_feedback'],\n",
319
+ " 'reward': 0.0, 'done': False,\n",
320
+ " 'risk_score': sigs.get('risk_score'),\n",
321
+ " 'policy_confidence': sigs.get('policy_confidence'),\n",
322
+ " 'image_flag': sigs.get('image_flag'),\n",
323
+ " 'landing_flag': sigs.get('landing_flag'),\n",
324
+ " 'last_error': sigs.get('last_error'),\n",
325
+ " },\n",
326
+ " }\n",
327
+ "\n",
328
+ "def build_dataset():\n",
329
+ " rows = []\n",
330
+ " for s in BASE_SCENARIOS:\n",
331
+ " obs = build_observation(s)\n",
332
+ " user_content = 'Current Ad Observation:\\n' + json.dumps(obs, indent=2) + '\\n\\nWhat is your next action?'\n",
333
+ " rows.append({\n",
334
+ " 'prompt': [\n",
335
+ " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
336
+ " {'role': 'user', 'content': user_content},\n",
337
+ " ],\n",
338
+ " 'task_id': s['task_id'],\n",
339
+ " 'setup_actions': s['setup_actions'],\n",
340
+ " })\n",
341
+ " return Dataset.from_list(rows * 8)\n",
342
+ "\n",
343
+ "dataset = build_dataset()\n",
344
+ "print(f'Dataset: {len(dataset)} examples ({len(BASE_SCENARIOS)} unique scenarios x 8)')"
345
+ ],
346
+ "execution_count": null,
347
+ "outputs": []
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "metadata": {},
352
+ "source": [
353
+ "## Cell 7 — Reward Function"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "metadata": {},
359
+ "source": [
360
+ "reward_log = []\n",
361
+ "step_log = []\n",
362
+ "global_step_counter = [0]\n",
363
+ "\n",
364
+ "def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):\n",
365
+ " client = EnvClient(ENV_URL)\n",
366
+ " rewards = []\n",
367
+ "\n",
368
+ " if task_id is None or setup_actions is None:\n",
369
+ " return [-1.0] * len(completions)\n",
370
+ "\n",
371
+ " for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):\n",
372
+ " parsed = extract_json(completion)\n",
373
+ " if not parsed:\n",
374
+ " rewards.append(-1.0)\n",
375
+ " continue\n",
376
+ "\n",
377
+ " action_type = parsed.get('action_type')\n",
378
+ " if action_type not in ALLOWED_ACTIONS:\n",
379
+ " rewards.append(-1.0)\n",
380
+ " continue\n",
381
+ "\n",
382
+ " action = {'action_type': action_type, 'reasoning': parsed.get('reasoning', 'format-compliant')}\n",
383
+ "\n",
384
+ " try:\n",
385
+ " random.seed(hash((t_id, len(setup))) % (2**32 - 1))\n",
386
+ " client.reset(t_id)\n",
387
+ " for s in setup:\n",
388
+ " safe_step(client, s)\n",
389
+ "\n",
390
+ " result = safe_step(client, action)\n",
391
+ " env_reward = float(result.get('reward', -0.2))\n",
392
+ " status_msg = (result.get('status_message') or '').lower()\n",
393
+ "\n",
394
+ " rejected = (\n",
395
+ " 'api failure' in status_msg\n",
396
+ " or 'invalid action' in status_msg\n",
397
+ " or 'must call' in status_msg\n",
398
+ " )\n",
399
+ "\n",
400
+ " if rejected:\n",
401
+ " shaped = -0.5\n",
402
+ " else:\n",
403
+ " shaped = 0.5 + env_reward\n",
404
+ " taken = set(a['action_type'] for a in setup)\n",
405
+ "\n",
406
+ " if not taken:\n",
407
+ " if action_type == 'query_regulations':\n",
408
+ " shaped += 0.15\n",
409
+ " elif 'submit_audit' in taken:\n",
410
+ " if action_type in ('approve', 'reject'):\n",
411
+ " shaped += 0.2\n",
412
+ " else:\n",
413
+ " shaped -= 0.1\n",
414
+ " elif 'query_regulations' in taken:\n",
415
+ " gathering = {'analyze_image','check_advertiser_history','request_landing_page','request_id_verification'}\n",
416
+ " if action_type in gathering:\n",
417
+ " shaped += 0.1\n",
418
+ " elif action_type == 'submit_audit':\n",
419
+ " shaped += 0.1\n",
420
+ " elif action_type in ('approve', 'reject'):\n",
421
+ " shaped -= 0.15\n",
422
+ "\n",
423
+ " if t_id == 'task_3_multimodal' and action_type == 'analyze_image':\n",
424
+ " shaped += 0.1\n",
425
+ " if t_id == 'task_4_targeting' and action_type == 'request_id_verification':\n",
426
+ " shaped += 0.1\n",
427
+ " if t_id in ('task_8_adversarial', 'task_9_dependency_trap') and action_type == 'analyze_image':\n",
428
+ " shaped += 0.1\n",
429
+ "\n",
430
+ " rewards.append(shaped)\n",
431
+ " except Exception:\n",
432
+ " rewards.append(-0.3)\n",
433
+ "\n",
434
+ " avg = sum(rewards) / len(rewards) if rewards else 0.0\n",
435
+ " global_step_counter[0] += 1\n",
436
+ " reward_log.append(avg)\n",
437
+ " step_log.append(global_step_counter[0])\n",
438
+ " return rewards\n",
439
+ "\n",
440
+ "print('Reward function ready')"
441
+ ],
442
+ "execution_count": null,
443
+ "outputs": []
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "metadata": {},
448
+ "source": [
449
+ "## Cell 8 — Load Model"
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "code",
454
+ "metadata": {},
455
+ "source": [
456
+ "if torch.cuda.is_available():\n",
457
+ " _props = torch.cuda.get_device_properties(0)\n",
458
+ " _vram = _props.total_memory\n",
459
+ " _cc = (_props.major, _props.minor)\n",
460
+ " print(f'GPU: {_props.name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}')\n",
461
+ "else:\n",
462
+ " _vram, _cc = 0, (0, 0)\n",
463
+ "\n",
464
+ "USE_4BIT = _vram < 40 * 1024**3 # T4/L4 → 4-bit; A100 → full precision\n",
465
+ "USE_BF16 = _cc >= (8, 0) and not USE_4BIT # bf16 only with full-precision weights; 4-bit LoRA uses fp16\n",
466
+ "print(f'4-bit: {USE_4BIT} bf16: {USE_BF16}')\n",
467
+ "\n",
468
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
469
+ " model_name='unsloth/Llama-3.1-8B-Instruct',\n",
470
+ " load_in_4bit=USE_4BIT,\n",
471
+ " max_seq_length=2048,\n",
472
+ " dtype=torch.float16 if USE_4BIT else None,\n",
473
+ ")\n",
474
+ "\n",
475
+ "model = FastLanguageModel.get_peft_model(\n",
476
+ " model,\n",
477
+ " r=16 if USE_4BIT else 32,\n",
478
+ " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
479
+ " lora_alpha=32 if USE_4BIT else 64,\n",
480
+ " lora_dropout=0,\n",
481
+ " bias='none',\n",
482
+ " use_gradient_checkpointing='unsloth',\n",
483
+ " random_state=3407,\n",
484
+ ")\n",
485
+ "print('Model loaded')"
486
+ ],
487
+ "execution_count": null,
488
+ "outputs": []
489
+ },
490
+ {
491
+ "cell_type": "markdown",
492
+ "metadata": {},
493
+ "source": [
494
+ "## Cell 9 — Train"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "metadata": {},
500
+ "source": [
501
+ "trainer = GRPOTrainer(\n",
502
+ " model=model,\n",
503
+ " reward_funcs=[reward_environment],\n",
504
+ " args=GRPOConfig(\n",
505
+ " output_dir='outputs',\n",
506
+ " learning_rate=5e-6,\n",
507
+ " num_train_epochs=1 if USE_4BIT else 2,\n",
508
+ " per_device_train_batch_size=1 if USE_4BIT else 2,\n",
509
+ " gradient_accumulation_steps=4,\n",
510
+ " num_generations=4,\n",
511
+ " max_prompt_length=512,\n",
512
+ " max_completion_length=80,\n",
513
+ " logging_steps=5,\n",
514
+ " warmup_steps=10,\n",
515
+ " bf16=USE_BF16,\n",
516
+ " fp16=not USE_BF16,\n",
517
+ " report_to='none',\n",
518
+ " ),\n",
519
+ " train_dataset=dataset,\n",
520
+ " tokenizer=tokenizer,\n",
521
+ ")\n",
522
+ "\n",
523
+ "print('Starting GRPO training...')\n",
524
+ "print(f' lr=5e-6 bf16={USE_BF16} fp16={not USE_BF16} batch={1 if USE_4BIT else 2} gens=4 epochs={1 if USE_4BIT else 2}')\n",
525
+ "trainer.train()\n",
526
+ "print('Training complete')"
527
+ ],
528
+ "execution_count": null,
529
+ "outputs": []
530
+ },
531
+ {
532
+ "cell_type": "markdown",
533
+ "metadata": {},
534
+ "source": [
535
+ "## Cell 10 — Plot Reward Curve"
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "metadata": {},
541
+ "source": [
542
+ "import matplotlib.pyplot as plt\n",
543
+ "import numpy as np\n",
544
+ "import pandas as pd\n",
545
+ "import os\n",
546
+ "\n",
547
+ "os.makedirs('outputs', exist_ok=True)\n",
548
+ "\n",
549
+ "def moving_avg(data, window=5):\n",
550
+ " if len(data) < window:\n",
551
+ " return data\n",
552
+ " return list(np.convolve(data, np.ones(window)/window, mode='valid'))\n",
553
+ "\n",
554
+ "hist = pd.DataFrame(trainer.state.log_history)\n",
555
+ "\n",
556
+ "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
557
+ "\n",
558
+ "# --- Plot 1: Reward curve (from our custom log) ---\n",
559
+ "ax = axes[0]\n",
560
+ "ax.plot(step_log, reward_log, alpha=0.3, color='steelblue', label='Raw')\n",
561
+ "smoothed = moving_avg(reward_log)\n",
562
+ "ax.plot(range(len(smoothed)), smoothed, color='steelblue', linewidth=2, label='Smoothed (MA-5)')\n",
563
+ "ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
564
+ "ax.set_xlabel('Reward Eval Step')\n",
565
+ "ax.set_ylabel('Avg Reward per Batch')\n",
566
+ "ax.set_title('Reward Curve')\n",
567
+ "ax.legend()\n",
568
+ "ax.grid(alpha=0.3)\n",
569
+ "\n",
570
+ "# --- Plot 2: Loss curve (from trainer logs) ---\n",
571
+ "ax = axes[1]\n",
572
+ "loss_rows = hist.dropna(subset=['loss']) if 'loss' in hist.columns else pd.DataFrame()\n",
573
+ "if not loss_rows.empty:\n",
574
+ " ax.plot(loss_rows['step'], loss_rows['loss'], color='#7c3aed', linewidth=2)\n",
575
+ " ax.set_xlabel('Training Step')\n",
576
+ " ax.set_ylabel('Loss')\n",
577
+ " ax.set_title('GRPO Loss')\n",
578
+ " ax.grid(alpha=0.3)\n",
579
+ "else:\n",
580
+ " ax.text(0.5, 0.5, 'No loss data logged', ha='center', va='center', transform=ax.transAxes)\n",
581
+ " ax.set_title('GRPO Loss')\n",
582
+ "\n",
583
+ "# --- Plot 3: Reward from trainer logs (if available) ---\n",
584
+ "ax = axes[2]\n",
585
+ "reward_cols = [c for c in hist.columns if 'reward' in c.lower() and 'std' not in c.lower()]\n",
586
+ "if reward_cols:\n",
587
+ " col = reward_cols[0]\n",
588
+ " rr = hist.dropna(subset=[col])\n",
589
+ " ax.plot(rr['step'], rr[col], color='#16a34a', linewidth=2)\n",
590
+ " ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
591
+ " ax.set_xlabel('Training Step')\n",
592
+ " ax.set_ylabel(col)\n",
593
+ " ax.set_title('Trainer Reward Log')\n",
594
+ " ax.grid(alpha=0.3)\n",
595
+ "else:\n",
596
+ " ax.text(0.5, 0.5, 'No trainer reward data', ha='center', va='center', transform=ax.transAxes)\n",
597
+ " ax.set_title('Trainer Reward Log')\n",
598
+ "\n",
599
+ "plt.tight_layout()\n",
600
+ "plt.savefig('outputs/training_plots.png', dpi=150)\n",
601
+ "plt.show()\n",
602
+ "print('Saved to outputs/training_plots.png')\n",
603
+ "\n",
604
+ "n = len(reward_log)\n",
605
+ "first_10 = reward_log[:min(10, n)]\n",
606
+ "last_10 = reward_log[max(0, n-10):]\n",
607
+ "print(f'\\n--- Before vs After ---')\n",
608
+ "print(f'Avg reward (first 10 steps): {sum(first_10)/len(first_10):.3f}')\n",
609
+ "print(f'Avg reward (last 10 steps) : {sum(last_10)/len(last_10):.3f}')"
610
+ ],
611
+ "execution_count": null,
612
+ "outputs": []
613
+ },
614
+ {
615
+ "cell_type": "markdown",
616
+ "metadata": {},
617
+ "source": [
618
+ "## Cell 11 — Before vs After: Baseline Comparison"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "metadata": {},
624
+ "source": [
625
+ "from unsloth import FastLanguageModel as FLM\n",
626
+ "FLM.for_inference(model)\n",
627
+ "\n",
628
+ "EVAL_SCENARIOS = [\n",
629
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':1,\n",
630
+ " 'actions_already_taken':[],'setup_actions':[],\n",
631
+ " 'last_feedback':'Ad loaded for task_1_healthcare. Begin with query_regulations.',\n",
632
+ " 'signals':{},'expected':'query_regulations'},\n",
633
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':1,\n",
634
+ " 'actions_already_taken':[],'setup_actions':[],\n",
635
+ " 'last_feedback':'Ad loaded for task_2_financial. Begin with query_regulations.',\n",
636
+ " 'signals':{},'expected':'query_regulations'},\n",
637
+ " {'task_id':'task_3_multimodal','ad_key':'task_3_multimodal','step_count':2,\n",
638
+ " 'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),\n",
639
+ " 'last_feedback':'policy_confidence=0.65',\n",
640
+ " 'signals':{'policy_confidence':0.65},'expected':'analyze_image'},\n",
641
+ " {'task_id':'task_4_targeting','ad_key':'task_4_targeting','step_count':2,\n",
642
+ " 'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),\n",
643
+ " 'last_feedback':'policy_confidence=0.70',\n",
644
+ " 'signals':{'policy_confidence':0.70},'expected':'request_id_verification'},\n",
645
+ " {'task_id':'task_1_healthcare','ad_key':'task_1_healthcare','step_count':3,\n",
646
+ " 'actions_already_taken':['query_regulations','check_advertiser_history'],\n",
647
+ " 'setup_actions':_sa('query_regulations','check_advertiser_history'),\n",
648
+ " 'last_feedback':'risk_score=0.82',\n",
649
+ " 'signals':{'policy_confidence':0.92,'risk_score':0.82},'expected':'submit_audit'},\n",
650
+ " {'task_id':'task_2_financial','ad_key':'task_2_financial','step_count':4,\n",
651
+ " 'actions_already_taken':['query_regulations','check_advertiser_history','submit_audit'],\n",
652
+ " 'setup_actions':_sa('query_regulations','check_advertiser_history','submit_audit'),\n",
653
+ " 'last_feedback':'audit_logged id=AUD-002',\n",
654
+ " 'signals':{'policy_confidence':0.88,'risk_score':0.75},'expected':'reject'},\n",
655
+ " {'task_id':'task_7_ambiguous','ad_key':'task_7_ambiguous','step_count':2,\n",
656
+ " 'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),\n",
657
+ " 'last_feedback':'policy_confidence=0.42',\n",
658
+ " 'signals':{'policy_confidence':0.42},'expected':'check_advertiser_history'},\n",
659
+ " {'task_id':'task_8_adversarial','ad_key':'task_8_adversarial','step_count':2,\n",
660
+ " 'actions_already_taken':['query_regulations'],'setup_actions':_sa('query_regulations'),\n",
661
+ " 'last_feedback':'policy_confidence=0.75',\n",
662
+ " 'signals':{'policy_confidence':0.75},'expected':'analyze_image'},\n",
663
+ "]\n",
664
+ "\n",
665
+ "def eval_model_on_scenarios(model, tokenizer, scenarios, label='Model'):\n",
666
+ " json_ok, phase_ok, total = 0, 0, len(scenarios)\n",
667
+ " results = []\n",
668
+ " for s in scenarios:\n",
669
+ " obs = build_observation(s)\n",
670
+ " user_content = 'Current Ad Observation:\\n' + json.dumps(obs, indent=2) + '\\n\\nWhat is your next action?'\n",
671
+ " messages = [{'role':'system','content':SYSTEM_PROMPT},{'role':'user','content':user_content}]\n",
672
+ " prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
673
+ " inputs = tokenizer(prompt_str, return_tensors='pt').to('cuda')\n",
674
+ " out = model.generate(**inputs, max_new_tokens=64, temperature=0.1, do_sample=True)\n",
675
+ " decoded = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
676
+ " parsed = extract_json(decoded)\n",
677
+ " is_json = parsed is not None\n",
678
+ " action = parsed.get('action_type','') if parsed else ''\n",
679
+ " is_correct = action == s['expected']\n",
680
+ " json_ok += int(is_json)\n",
681
+ " phase_ok += int(is_correct)\n",
682
+ " results.append({'task':s['task_id'],'expected':s['expected'],'got':action,'ok':is_correct})\n",
683
+ " print(f'\\n=== {label} ({total} scenarios) ===')\n",
684
+ " print(f' JSON parse rate : {json_ok}/{total} ({100*json_ok/total:.0f}%)')\n",
685
+ " print(f' Correct action : {phase_ok}/{total} ({100*phase_ok/total:.0f}%)')\n",
686
+ " for r in results:\n",
687
+ " mark = 'OK' if r['ok'] else 'MISS'\n",
688
+ " print(f\" [{mark}] {r['task']:25s} expected={r['expected']:30s} got={r['got']}\")\n",
689
+ " return {'json_rate': json_ok/total, 'phase_rate': phase_ok/total}\n",
690
+ "\n",
691
+ "trained_metrics = eval_model_on_scenarios(model, tokenizer, EVAL_SCENARIOS, 'Trained Model')\n",
692
+ "\n",
693
+ "fig, ax = plt.subplots(figsize=(8, 4))\n",
694
+ "metrics = ['JSON Parse Rate', 'Correct Phase Action']\n",
695
+ "trained_vals = [trained_metrics['json_rate'], trained_metrics['phase_rate']]\n",
696
+ "x = range(len(metrics))\n",
697
+ "bars = ax.bar(x, trained_vals, width=0.5, color='#2563eb', label='After GRPO')\n",
698
+ "ax.set_xticks(x)\n",
699
+ "ax.set_xticklabels(metrics)\n",
700
+ "ax.set_ylim(0, 1.05)\n",
701
+ "ax.set_ylabel('Rate')\n",
702
+ "ax.set_title('Trained Model Evaluation')\n",
703
+ "for bar, val in zip(bars, trained_vals):\n",
704
+ " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f'{val:.0%}', ha='center', fontweight='bold')\n",
705
+ "ax.legend()\n",
706
+ "ax.grid(axis='y', alpha=0.3)\n",
707
+ "plt.tight_layout()\n",
708
+ "plt.savefig('outputs/before_after_comparison.png', dpi=150)\n",
709
+ "plt.show()\n",
710
+ "print('Saved to outputs/before_after_comparison.png')"
711
+ ],
712
+ "execution_count": null,
713
+ "outputs": []
714
+ },
715
+ {
716
+ "cell_type": "markdown",
717
+ "metadata": {},
718
+ "source": [
719
+ "## Cell 12 — Save + Push to HF Hub"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "metadata": {},
725
+ "source": [
726
+ "model.save_pretrained('outputs/lora_adapter')\n",
727
+ "tokenizer.save_pretrained('outputs/lora_adapter')\n",
728
+ "print('LoRA adapter saved')\n",
729
+ "\n",
730
+ "print('Merging adapter into base model...')\n",
731
+ "merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n",
732
+ " model_name='outputs/lora_adapter',\n",
733
+ " load_in_4bit=False,\n",
734
+ " max_seq_length=2048,\n",
735
+ ")\n",
736
+ "merged_model.save_pretrained_merged(\n",
737
+ " 'outputs/merged',\n",
738
+ " merged_tokenizer,\n",
739
+ " save_method='merged_16bit',\n",
740
+ ")\n",
741
+ "print('Merged model saved to outputs/merged')\n",
742
+ "\n",
743
+ "if HF_REPO and HF_TOKEN:\n",
744
+ " print(f'Pushing to {HF_REPO}...')\n",
745
+ " merged_model.push_to_hub_merged(\n",
746
+ " HF_REPO,\n",
747
+ " merged_tokenizer,\n",
748
+ " save_method='merged_16bit',\n",
749
+ " token=HF_TOKEN,\n",
750
+ " )\n",
751
+ " print(f'Model live at https://huggingface.co/{HF_REPO}')\n",
752
+ "else:\n",
753
+ " print('Set HF_REPO and HF_TOKEN in Cell 3 to push to Hub')\n",
754
+ "\n",
755
+ "print('Done.')"
756
+ ],
757
+ "execution_count": null,
758
+ "outputs": []
759
+ }
760
+ ],
761
+ "metadata": {
762
+ "colab": {
763
+ "provenance": [],
764
+ "gpuType": "A100"
765
+ },
766
+ "kernelspec": {
767
+ "display_name": "Python 3",
768
+ "name": "python3"
769
+ },
770
+ "language_info": {
771
+ "name": "python"
772
+ },
773
+ "accelerator": "GPU"
774
+ },
775
+ "nbformat": 4,
776
+ "nbformat_minor": 0
777
+ }