parth-1 committed on
Commit
70c7c72
·
verified ·
1 Parent(s): e4a8c57

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +441 -207
grpo_train.py CHANGED
@@ -13,33 +13,18 @@ from trl import GRPOTrainer, GRPOConfig
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
16
- # #region agent log
17
- import pathlib as _pl
18
- _DLOG = _pl.Path("debug-851b5f.log")
19
- def _dlog(hyp, loc, msg, data=None):
20
- import time as _t
21
- entry = json.dumps({"sessionId":"851b5f","hypothesisId":hyp,"location":loc,"message":msg,"data":data or {},"timestamp":int(_t.time()*1000)})
22
- with open(_DLOG, "a") as f: f.write(entry + "\n")
23
- print(f"[DBG:{hyp}] {msg} {data or ''}", flush=True)
24
- # #endregion
25
-
26
  # =========================
27
  # CONFIG
28
  # =========================
29
 
30
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
31
  HF_TOKEN = os.getenv("HF_TOKEN", "")
32
- HF_REPO = os.getenv("HF_REPO", "") # e.g. "yourname/metaguard-llama3.1-8b-grpo"
33
 
34
  ALLOWED_ACTIONS = [
35
- "query_regulations",
36
- "analyze_image",
37
- "check_advertiser_history",
38
- "request_landing_page",
39
- "request_id_verification",
40
- "submit_audit",
41
- "approve",
42
- "reject",
43
  ]
44
 
45
  # =========================
@@ -47,35 +32,23 @@ ALLOWED_ACTIONS = [
47
  # =========================
48
 
49
  def ensure_env_ready():
50
- # #region agent log
51
- _dlog("B", "grpo_train.py:ensure_env_ready", "Checking env", {"ENV_URL": ENV_URL})
52
- # #endregion
53
  for i in range(20):
54
  try:
55
  r = requests.post(
56
  f"{ENV_URL}/reset",
57
  json={"task_id": "task_1_healthcare"},
58
- timeout=5
59
  )
60
  if r.status_code == 200:
61
- # #region agent log
62
- _dlog("B", "grpo_train.py:ensure_env_ready", "Env ready", {"attempt": i+1, "status": r.status_code})
63
- # #endregion
64
- print("✅ Environment ready")
65
  return
66
- except Exception as e:
67
- # #region agent log
68
- if i == 0: _dlog("B", "grpo_train.py:ensure_env_ready", "Env connection failed", {"attempt": i+1, "error": str(e)[:200]})
69
- # #endregion
70
  pass
71
  time.sleep(1)
72
- # #region agent log
73
- _dlog("B", "grpo_train.py:ensure_env_ready", "ENV UNREACHABLE after 20 attempts", {})
74
- # #endregion
75
- raise RuntimeError("❌ ENV not reachable")
76
 
77
  # =========================
78
- # SAFE CLIENT
79
  # =========================
80
 
81
  class EnvClient:
@@ -84,30 +57,22 @@ class EnvClient:
84
 
85
  def reset(self, task_id):
86
  return requests.post(
87
- f"{self.url}/reset",
88
- json={"task_id": task_id},
89
- timeout=8
90
  ).json()
91
 
92
  def step(self, action):
93
  return requests.post(
94
- f"{self.url}/step",
95
- json={"action": action},
96
- timeout=8
97
  ).json()
98
 
99
  def safe_step(client, action):
100
  for _ in range(3):
101
  try:
102
  return client.step(action)
103
- except:
104
  time.sleep(0.5)
105
  return {"reward": -0.3}
106
 
107
- # =========================
108
- # JSON PARSER
109
- # =========================
110
-
111
  def extract_json(text):
112
  try:
113
  if "```" in text:
@@ -115,186 +80,436 @@ def extract_json(text):
115
  if text.startswith("json"):
116
  text = text[4:]
117
  return json.loads(text.strip())
118
- except:
119
  return None
120
 
121
  # =========================
122
- # DATASET (WITH SETUP ACTIONS)
123
  # =========================
124
 
125
- BASE_SCENARIOS = [
126
- # Phase 1 Fresh state, expected: query_regulations
127
- {
128
- "task_id": "task_1_healthcare",
129
- "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
130
- "actions_already_taken": [],
131
- "setup_actions": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  },
133
- {
134
- "task_id": "task_2_financial",
135
- "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
136
- "actions_already_taken": [],
137
- "setup_actions": [],
 
138
  },
139
- {
140
- "task_id": "task_3_multimodal",
141
- "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
142
- "actions_already_taken": [],
143
- "setup_actions": [],
 
144
  },
145
-
146
- # Phase 2 — Policy checked, expected: analyze_image OR check_advertiser_history
147
- {
148
- "task_id": "task_1_healthcare",
149
- "text": "Healthcare ad: pharma product. Policy already queried.",
150
- "actions_already_taken": ["query_regulations"],
151
- "setup_actions": [
152
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
153
- ],
154
  },
155
- {
156
- "task_id": "task_3_multimodal",
157
- "text": "Multimodal ad: image not yet inspected. Policy already queried.",
158
- "actions_already_taken": ["query_regulations"],
159
- "setup_actions": [
160
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
161
- ],
162
  },
163
-
164
- # Phase 3 Policy + history checked, expected: submit_audit
165
- {
166
- "task_id": "task_2_financial",
167
- "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
168
- "actions_already_taken": ["query_regulations", "check_advertiser_history"],
169
- "setup_actions": [
170
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
171
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
172
- ],
173
  },
174
-
175
- # Phase 4 — Audit complete, expected: reject (high-risk) or approve (clean)
176
- {
177
- "task_id": "task_2_financial",
178
- "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
179
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
180
- "setup_actions": [
181
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
182
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
183
- {"action_type": "submit_audit", "reasoning": "audit log"},
184
- ],
185
  },
186
-
187
- # Targeting task fresh state, expected: query_regulations
188
- {
189
- "task_id": "task_4_targeting",
190
- "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
191
- "actions_already_taken": [],
192
- "setup_actions": [],
193
  },
194
- # Targeting task — mid state, expected: request_id_verification (age check)
195
- {
196
- "task_id": "task_4_targeting",
197
- "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
198
- "actions_already_taken": ["query_regulations"],
199
- "setup_actions": [
200
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
201
- ],
202
  },
203
- # Targeting task — audit ready
204
- {
205
- "task_id": "task_4_targeting",
206
- "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
207
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
208
- "setup_actions": [
209
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
210
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
211
- {"action_type": "request_id_verification", "reasoning": "age check"},
212
- ],
213
  },
214
- ]
215
-
216
- PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
217
-
218
- You MUST choose exactly ONE action_type from this list (any other value is invalid):
219
- - query_regulations
220
- - analyze_image
221
- - check_advertiser_history
222
- - request_landing_page
223
- - request_id_verification
224
- - submit_audit
225
- - approve
226
- - reject
 
 
227
 
228
- REQUIRED PHASE ORDER:
229
- 1. query_regulations -> always first
230
- 2. analyze_image / check_advertiser_history -> gather signals
231
- 3. submit_audit -> always before final decision
232
- 4. approve OR reject -> only after audit
233
 
234
- HARD RULES:
235
- - NEVER repeat an action listed in `actions_already_taken`.
236
- - Respond with ONLY a valid JSON object. No markdown, no prose.
237
 
238
- Required format:
239
- {{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- Scenario: {text}
242
- actions_already_taken: {actions_already_taken}
 
243
 
244
- Your next action?"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
 
247
  def build_dataset():
248
  rows = []
249
  for s in BASE_SCENARIOS:
250
- prompt = PROMPT_TEMPLATE.format(
251
- text=s["text"],
252
- actions_already_taken=json.dumps(s["actions_already_taken"]),
 
 
253
  )
254
  rows.append({
255
- "prompt": prompt,
 
 
 
256
  "task_id": s["task_id"],
257
  "setup_actions": s["setup_actions"],
258
  })
259
- return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
260
 
261
  # =========================
262
- # REWARD FUNCTION (FIXED)
263
  # =========================
264
 
265
- _reward_call_count = [0]
266
-
267
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
268
- """Shaped reward for GRPO."""
269
- _reward_call_count[0] += 1
270
- _call = _reward_call_count[0]
271
- # #region agent log
272
- _dlog("C", "grpo_train.py:reward_env", f"reward call #{_call}", {
273
- "n_prompts": len(prompts) if prompts else 0,
274
- "n_completions": len(completions) if completions else 0,
275
- "completions_type": type(completions).__name__,
276
- "first_completion_type": type(completions[0]).__name__ if completions else "N/A",
277
- "first_completion_preview": str(completions[0])[:150] if completions else "N/A",
278
- "task_id_is_none": task_id is None,
279
- "setup_actions_is_none": setup_actions is None,
280
- "kwargs_keys": list(kwargs.keys()),
281
- })
282
- # #endregion
283
-
284
  client = EnvClient(ENV_URL)
285
  rewards = []
286
 
287
  if task_id is None or setup_actions is None:
288
- # #region agent log
289
- _dlog("D", "grpo_train.py:reward_env", "task_id or setup_actions is None — returning -1 for all", {"call": _call})
290
- # #endregion
291
  return [-1.0] * len(completions)
292
 
293
  for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
294
  parsed = extract_json(completion)
295
- # #region agent log
296
- if _call <= 3: _dlog("D", "grpo_train.py:reward_loop", f"call#{_call} item#{idx}", {"parsed_ok": parsed is not None, "action": parsed.get("action_type") if parsed else None, "raw_preview": str(completion)[:120], "task_id": t_id})
297
- # #endregion
298
  if not parsed:
299
  rewards.append(-1.0)
300
  continue
@@ -310,6 +525,7 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
310
  }
311
 
312
  try:
 
313
  client.reset(t_id)
314
  for s in setup:
315
  safe_step(client, s)
@@ -329,6 +545,35 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
329
  else:
330
  shaped = 0.5 + env_reward
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  rewards.append(shaped)
333
 
334
  except Exception:
@@ -344,19 +589,15 @@ if torch.cuda.is_available():
344
  _props = torch.cuda.get_device_properties(0)
345
  _vram = _props.total_memory
346
  _name = _props.name
347
- _cc = (_props.major, _props.minor) # compute capability
348
  print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}")
349
  else:
350
  _vram = 0
351
  _name = "CPU"
352
  _cc = (0, 0)
353
 
354
- USE_4BIT = _vram < 40 * 1024**3 # T4 (15 GB), L4 (24 GB) → 4-bit; A100 (80 GB) → full
355
- USE_BF16 = _cc >= (8, 0) and not USE_4BIT # bf16 only when full-precision; 4-bit LoRA uses fp16 internally
356
-
357
- # #region agent log
358
- _dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})
359
- # #endregion
360
 
361
  model, tokenizer = FastLanguageModel.from_pretrained(
362
  model_name="unsloth/Llama-3.1-8B-Instruct",
@@ -385,24 +626,20 @@ model = FastLanguageModel.get_peft_model(
385
 
386
  dataset = build_dataset()
387
 
388
- # #region agent log
389
- _dlog("A", "grpo_train.py:trainer_init", "Creating GRPOTrainer", {"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16,"epochs":1 if USE_4BIT else 3,"batch":1 if USE_4BIT else 2,"gens":2 if USE_4BIT else 4,"dataset_len":len(dataset)})
390
- # #endregion
391
-
392
  trainer = GRPOTrainer(
393
  model=model,
394
  reward_funcs=[reward_environment],
395
  args=GRPOConfig(
396
  output_dir="outputs",
397
- learning_rate=2e-5,
398
- num_train_epochs=1 if USE_4BIT else 3,
399
  per_device_train_batch_size=1 if USE_4BIT else 2,
400
- gradient_accumulation_steps=2 if USE_4BIT else 4,
401
- num_generations=2 if USE_4BIT else 4,
402
- max_prompt_length=768,
403
- max_completion_length=128,
404
- logging_steps=3 if USE_4BIT else 5,
405
- warmup_steps=5 if USE_4BIT else 10,
406
  bf16=USE_BF16,
407
  fp16=not USE_BF16,
408
  report_to="none",
@@ -418,9 +655,6 @@ trainer = GRPOTrainer(
418
  if __name__ == "__main__":
419
  ensure_env_ready()
420
 
421
- # #region agent log
422
- _dlog("E", "grpo_train.py:train_start", "About to call trainer.train()", {"gpu_mem_allocated_gb": round(torch.cuda.memory_allocated()/1024**3, 2) if torch.cuda.is_available() else 0})
423
- # #endregion
424
  print("Starting GRPO training...")
425
  trainer.train()
426
 
@@ -428,7 +662,7 @@ if __name__ == "__main__":
428
  tokenizer.save_pretrained("outputs/lora_adapter")
429
  print("LoRA adapter saved to outputs/lora_adapter")
430
 
431
- print("Merging adapter into base model (bf16)...")
432
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
433
  model_name="outputs/lora_adapter",
434
  load_in_4bit=False,
 
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
 
 
 
 
 
 
 
 
 
 
16
  # =========================
17
  # CONFIG
18
  # =========================
19
 
20
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
21
  HF_TOKEN = os.getenv("HF_TOKEN", "")
22
+ HF_REPO = os.getenv("HF_REPO", "")
23
 
24
  ALLOWED_ACTIONS = [
25
+ "query_regulations", "analyze_image", "check_advertiser_history",
26
+ "request_landing_page", "request_id_verification",
27
+ "submit_audit", "approve", "reject",
 
 
 
 
 
28
  ]
29
 
30
  # =========================
 
32
  # =========================
33
 
def ensure_env_ready(attempts=20, delay=1.0):
    """Block until the environment server answers ``/reset`` with HTTP 200.

    Sends ``POST {ENV_URL}/reset`` for ``task_1_healthcare`` up to
    ``attempts`` times, pausing ``delay`` seconds between tries. Prints
    "Environment ready" and returns on the first 200 response.

    Args:
        attempts: Maximum number of reset requests to send.
        delay: Seconds to wait between consecutive attempts.

    Raises:
        RuntimeError: If no attempt returned HTTP 200. When the last failure
            was a connection-level error, it is chained as ``__cause__`` so
            the reason the environment was unreachable is not lost.
    """
    last_error = None
    for attempt in range(attempts):
        try:
            r = requests.post(
                f"{ENV_URL}/reset",
                json={"task_id": "task_1_healthcare"},
                timeout=5,
            )
        except requests.RequestException as exc:
            # Server not up yet (refused, timeout, ...): remember why and retry.
            last_error = exc
        else:
            if r.status_code == 200:
                print("Environment ready")
                return
            # Server answered, but not with 200; no exception to chain.
            last_error = None
        # No point sleeping after the final attempt; we are about to raise.
        if attempt < attempts - 1:
            time.sleep(delay)
    raise RuntimeError(
        f"ENV not reachable after {attempts} attempts"
    ) from last_error
 
 
 
49
 
50
  # =========================
51
+ # ENV CLIENT
52
  # =========================
53
 
54
  class EnvClient:
 
57
 
58
  def reset(self, task_id):
59
  return requests.post(
60
+ f"{self.url}/reset", json={"task_id": task_id}, timeout=8,
 
 
61
  ).json()
62
 
63
  def step(self, action):
64
  return requests.post(
65
+ f"{self.url}/step", json={"action": action}, timeout=8,
 
 
66
  ).json()
67
 
def safe_step(client, action, *, retries=3, delay=0.5, fallback_reward=-0.3):
    """Call ``client.step(action)`` with retries, never raising.

    Args:
        client: Object exposing ``step(action)``.
        action: Payload forwarded unchanged to ``client.step``.
        retries: Maximum number of calls to ``client.step``.
        delay: Seconds to wait between failed attempts.
        fallback_reward: Value placed under ``"reward"`` in the dict returned
            when every attempt fails.

    Returns:
        Whatever ``client.step`` returns on the first successful attempt, or
        ``{"reward": fallback_reward}`` if all attempts raised.
    """
    for attempt in range(retries):
        try:
            return client.step(action)
        except Exception:
            # Deliberate best-effort: any failure of the step call is retried.
            # Skip the pause after the last attempt, since nothing follows it.
            if attempt < retries - 1:
                time.sleep(delay)
    return {"reward": fallback_reward}
75
 
 
 
 
 
76
  def extract_json(text):
77
  try:
78
  if "```" in text:
 
80
  if text.startswith("json"):
81
  text = text[4:]
82
  return json.loads(text.strip())
83
+ except Exception:
84
  return None
85
 
86
  # =========================
87
+ # SYSTEM PROMPT (identical to inference.py)
88
  # =========================
89
 
90
+ SYSTEM_PROMPT = (
91
+ "You are an enterprise Ad Policy Compliance Agent.\n"
92
+ "You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.\n"
93
+ "\n"
94
+ "REQUIRED PHASE ORDER:\n"
95
+ "1. query_regulations \u2014 always first\n"
96
+ "2. analyze_image \u2014 required for visual/multimodal tasks\n"
97
+ "3. check_advertiser_history or request_landing_page \u2014 as needed\n"
98
+ "4. submit_audit \u2014 always before final decision\n"
99
+ "5. approve or reject \u2014 final decision only after audit\n"
100
+ "\n"
101
+ "AVAILABLE ACTIONS:\n"
102
+ "- query_regulations\n"
103
+ "- analyze_image\n"
104
+ "- check_advertiser_history\n"
105
+ "- request_landing_page\n"
106
+ "- request_id_verification\n"
107
+ "- submit_audit\n"
108
+ "- approve\n"
109
+ "- reject\n"
110
+ "\n"
111
+ "HARD RULES:\n"
112
+ "- NEVER repeat an action listed in `actions_already_taken`.\n"
113
+ "- You MUST progress through the phase order. Do NOT call submit_audit or approve/reject\n"
114
+ " before the prerequisite phases are complete.\n"
115
+ "- Choose your action_type ONLY from the AVAILABLE ACTIONS list above. Any other value is invalid.\n"
116
+ "\n"
117
+ 'Response format:\n'
118
+ '{"action_type": "<action>", "reasoning": "<brief reason>"}'
119
+ )
120
+
121
+ # =========================
122
+ # FIXED AD DATA (deterministic observations per task)
123
+ # =========================
124
+
125
+ TASK_AD_DATA = {
126
+ "task_1_healthcare": {
127
+ "ad_id": "hc-001", "headline": "Secret D1abet3s Cure",
128
+ "body_text": "Reverse it today with this root extract!",
129
+ "advertiser_trust_score": 0.25,
130
+ "targeting_data": {"age_min": 25, "geo": "US"},
131
+ "image_url": "https://mock-meta.com/img/hc-001.jpg",
132
  },
133
+ "task_1_healthcare_b": {
134
+ "ad_id": "hc-002", "headline": "Optimize Metabolic Markers",
135
+ "body_text": "Our proprietary GLP-1 analog ships directly to your door without a visit.",
136
+ "advertiser_trust_score": 0.35,
137
+ "targeting_data": {"age_min": 30, "geo": "US"},
138
+ "image_url": "https://mock-meta.com/img/hc-002.jpg",
139
  },
140
+ "task_2_financial": {
141
+ "ad_id": "fn-001", "headline": "Garanteed 500% Returns",
142
+ "body_text": "Invest in LlamaCoin now. Zero risk.",
143
+ "advertiser_trust_score": 0.20,
144
+ "targeting_data": {"age_min": 21, "geo": "US"},
145
+ "image_url": "https://mock-meta.com/img/fn-001.jpg",
146
  },
147
+ "task_2_financial_b": {
148
+ "ad_id": "fn-002", "headline": "Beat the System",
149
+ "body_text": "Leverage our algorithmic trading bot to exploit market inefficiencies.",
150
+ "advertiser_trust_score": 0.30,
151
+ "targeting_data": {"age_min": 25, "geo": "US"},
152
+ "image_url": "https://mock-meta.com/img/fn-002.jpg",
 
 
 
153
  },
154
+ "task_3_multimodal": {
155
+ "ad_id": "mm-001", "headline": "Achieve Your Dream Body",
156
+ "body_text": "Hard work and our new program will get you there.",
157
+ "advertiser_trust_score": 0.90,
158
+ "targeting_data": {"age_min": 18, "geo": "US"},
159
+ "image_url": "https://mock-meta.com/img/mm-001.jpg",
 
160
  },
161
+ "task_3_multimodal_b": {
162
+ "ad_id": "mm-002", "headline": "Stay Focused All Day",
163
+ "body_text": "Our study supplement helps you crush your exams.",
164
+ "advertiser_trust_score": 0.88,
165
+ "targeting_data": {"age_min": 18, "geo": "US"},
166
+ "image_url": "https://mock-meta.com/img/mm-002.jpg",
 
 
 
 
167
  },
168
+ "task_4_targeting": {
169
+ "ad_id": "tg-001", "headline": "Start Your First Investment Portfolio",
170
+ "body_text": "Learn the basics of stock trading with our beginner-friendly app.",
171
+ "advertiser_trust_score": 0.72,
172
+ "targeting_data": {"age_min": 15, "geo": "US"},
173
+ "image_url": "https://mock-meta.com/img/tg-001.jpg",
 
 
 
 
 
174
  },
175
+ "task_6_conflict": {
176
+ "ad_id": "cf-001", "headline": "Smart Investment Tools",
177
+ "body_text": "Our platform helps you manage your portfolio with ease.",
178
+ "advertiser_trust_score": 0.92,
179
+ "targeting_data": {"age_min": 21, "geo": "US"},
180
+ "image_url": "https://mock-meta.com/img/cf-001.jpg",
 
181
  },
182
+ "task_7_ambiguous": {
183
+ "ad_id": "am-001", "headline": "Feel Better Naturally",
184
+ "body_text": "Our wellness program supports your health journey.",
185
+ "advertiser_trust_score": 0.55,
186
+ "targeting_data": {"age_min": 22, "geo": "US"},
187
+ "image_url": "https://mock-meta.com/img/am-001.jpg",
 
 
188
  },
189
+ "task_8_adversarial": {
190
+ "ad_id": "av-001", "headline": "Natural Energy Boost",
191
+ "body_text": "Feel amazing every day with our daily supplement.",
192
+ "advertiser_trust_score": 0.40,
193
+ "targeting_data": {"age_min": 20, "geo": "US"},
194
+ "image_url": "https://mock-meta.com/img/av-001.jpg",
 
 
 
 
195
  },
196
+ "task_9_dependency_trap": {
197
+ "ad_id": "dt-001", "headline": "Wellness Support Pack",
198
+ "body_text": "Certified vitamins to support your daily routine.",
199
+ "advertiser_trust_score": 0.85,
200
+ "targeting_data": {"age_min": 25, "geo": "US"},
201
+ "image_url": "https://mock-meta.com/img/dt-001.jpg",
202
+ },
203
+ "task_10_failure": {
204
+ "ad_id": "fr-001", "headline": "Miracle Weight Loss Patch",
205
+ "body_text": "Lose 10kg in 2 weeks. Guaranteed results or money back.",
206
+ "advertiser_trust_score": 0.15,
207
+ "targeting_data": {"age_min": 22, "geo": "US"},
208
+ "image_url": "https://mock-meta.com/img/fr-001.jpg",
209
+ },
210
+ }
211
 
212
+ # =========================
213
+ # SCENARIOS (43 unique across all tasks + phases)
214
+ # =========================
 
 
215
 
216
+ _sa = lambda *acts: [{"action_type": a, "reasoning": "setup"} for a in acts]
 
 
217
 
218
+ BASE_SCENARIOS = [
219
+ # ---- TASK 1: Healthcare (obvious violation) ----
220
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
221
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
222
+ "last_feedback": "Ad loaded for task_1_healthcare. Begin with query_regulations.",
223
+ "signals": {}},
224
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare_b",
225
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
226
+ "last_feedback": "Ad loaded for task_1_healthcare. Begin with query_regulations.",
227
+ "signals": {}},
228
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
229
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
230
+ "setup_actions": _sa("query_regulations"),
231
+ "last_feedback": "policy_confidence=0.92",
232
+ "signals": {"policy_confidence": 0.92}},
233
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare_b",
234
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
235
+ "setup_actions": _sa("query_regulations"),
236
+ "last_feedback": "policy_confidence=0.78",
237
+ "signals": {"policy_confidence": 0.78}},
238
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
239
+ "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
240
+ "setup_actions": _sa("query_regulations", "check_advertiser_history"),
241
+ "last_feedback": "risk_score=0.82",
242
+ "signals": {"policy_confidence": 0.92, "risk_score": 0.82}},
243
+ {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
244
+ "step_count": 4,
245
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
246
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
247
+ "last_feedback": "audit_logged id=AUD-001",
248
+ "signals": {"policy_confidence": 0.92, "risk_score": 0.82}},
249
+
250
+ # ---- TASK 2: Financial (obvious violation) ----
251
+ {"task_id": "task_2_financial", "ad_key": "task_2_financial",
252
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
253
+ "last_feedback": "Ad loaded for task_2_financial. Begin with query_regulations.",
254
+ "signals": {}},
255
+ {"task_id": "task_2_financial", "ad_key": "task_2_financial_b",
256
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
257
+ "last_feedback": "Ad loaded for task_2_financial. Begin with query_regulations.",
258
+ "signals": {}},
259
+ {"task_id": "task_2_financial", "ad_key": "task_2_financial",
260
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
261
+ "setup_actions": _sa("query_regulations"),
262
+ "last_feedback": "policy_confidence=0.88",
263
+ "signals": {"policy_confidence": 0.88}},
264
+ {"task_id": "task_2_financial", "ad_key": "task_2_financial",
265
+ "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
266
+ "setup_actions": _sa("query_regulations", "check_advertiser_history"),
267
+ "last_feedback": "risk_score=0.75",
268
+ "signals": {"policy_confidence": 0.88, "risk_score": 0.75}},
269
+ {"task_id": "task_2_financial", "ad_key": "task_2_financial",
270
+ "step_count": 4,
271
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
272
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
273
+ "last_feedback": "audit_logged id=AUD-002",
274
+ "signals": {"policy_confidence": 0.88, "risk_score": 0.75}},
275
+
276
+ # ---- TASK 3: Multimodal (violation hidden in image) ----
277
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
278
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
279
+ "last_feedback": "Ad loaded for task_3_multimodal. Begin with query_regulations.",
280
+ "signals": {}},
281
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal_b",
282
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
283
+ "last_feedback": "Ad loaded for task_3_multimodal. Begin with query_regulations.",
284
+ "signals": {}},
285
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
286
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
287
+ "setup_actions": _sa("query_regulations"),
288
+ "last_feedback": "policy_confidence=0.65",
289
+ "signals": {"policy_confidence": 0.65}},
290
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
291
+ "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
292
+ "setup_actions": _sa("query_regulations", "analyze_image"),
293
+ "last_feedback": "image_violation_detected",
294
+ "signals": {"policy_confidence": 0.65, "image_flag": True}},
295
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
296
+ "step_count": 4,
297
+ "actions_already_taken": ["query_regulations", "analyze_image", "check_advertiser_history"],
298
+ "setup_actions": _sa("query_regulations", "analyze_image", "check_advertiser_history"),
299
+ "last_feedback": "risk_score=0.45",
300
+ "signals": {"policy_confidence": 0.65, "image_flag": True, "risk_score": 0.45}},
301
+ {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
302
+ "step_count": 5,
303
+ "actions_already_taken": ["query_regulations", "analyze_image", "check_advertiser_history", "submit_audit"],
304
+ "setup_actions": _sa("query_regulations", "analyze_image", "check_advertiser_history", "submit_audit"),
305
+ "last_feedback": "audit_logged id=AUD-003",
306
+ "signals": {"policy_confidence": 0.65, "image_flag": True, "risk_score": 0.45}},
307
+
308
+ # ---- TASK 4: Targeting (minors) ----
309
+ {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
310
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
311
+ "last_feedback": "Ad loaded for task_4_targeting. Begin with query_regulations.",
312
+ "signals": {}},
313
+ {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
314
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
315
+ "setup_actions": _sa("query_regulations"),
316
+ "last_feedback": "policy_confidence=0.70",
317
+ "signals": {"policy_confidence": 0.70}},
318
+ {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
319
+ "step_count": 3, "actions_already_taken": ["query_regulations", "request_id_verification"],
320
+ "setup_actions": _sa("query_regulations", "request_id_verification"),
321
+ "last_feedback": "ALERT: minor targeting age=15",
322
+ "signals": {"policy_confidence": 0.70}},
323
+ {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
324
+ "step_count": 4,
325
+ "actions_already_taken": ["query_regulations", "request_id_verification", "check_advertiser_history"],
326
+ "setup_actions": _sa("query_regulations", "request_id_verification", "check_advertiser_history"),
327
+ "last_feedback": "risk_score=0.60",
328
+ "signals": {"policy_confidence": 0.70, "risk_score": 0.60}},
329
+ {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
330
+ "step_count": 5,
331
+ "actions_already_taken": ["query_regulations", "request_id_verification", "check_advertiser_history", "submit_audit"],
332
+ "setup_actions": _sa("query_regulations", "request_id_verification", "check_advertiser_history", "submit_audit"),
333
+ "last_feedback": "audit_logged id=AUD-004",
334
+ "signals": {"policy_confidence": 0.70, "risk_score": 0.60}},
335
+
336
+ # ---- TASK 6: Conflict (high trust + risky) ----
337
+ {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
338
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
339
+ "last_feedback": "Ad loaded for task_6_conflict. Begin with query_regulations.",
340
+ "signals": {}},
341
+ {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
342
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
343
+ "setup_actions": _sa("query_regulations"),
344
+ "last_feedback": "policy_confidence=0.72",
345
+ "signals": {"policy_confidence": 0.72}},
346
+ {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
347
+ "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
348
+ "setup_actions": _sa("query_regulations", "check_advertiser_history"),
349
+ "last_feedback": "risk_score=0.78",
350
+ "signals": {"policy_confidence": 0.72, "risk_score": 0.78}},
351
+ {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
352
+ "step_count": 4,
353
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
354
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
355
+ "last_feedback": "audit_logged id=AUD-006",
356
+ "signals": {"policy_confidence": 0.72, "risk_score": 0.78}},
357
+
358
+ # ---- TASK 7: Ambiguous (low confidence, need extra signals) ----
359
+ {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
360
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
361
+ "last_feedback": "Ad loaded for task_7_ambiguous. Begin with query_regulations.",
362
+ "signals": {}},
363
+ {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
364
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
365
+ "setup_actions": _sa("query_regulations"),
366
+ "last_feedback": "policy_confidence=0.42",
367
+ "signals": {"policy_confidence": 0.42}},
368
+ {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
369
+ "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
370
+ "setup_actions": _sa("query_regulations", "check_advertiser_history"),
371
+ "last_feedback": "risk_score=0.55",
372
+ "signals": {"policy_confidence": 0.42, "risk_score": 0.55}},
373
+ {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
374
+ "step_count": 4,
375
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_landing_page"],
376
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "request_landing_page"),
377
+ "last_feedback": "landing_suspicious",
378
+ "signals": {"policy_confidence": 0.42, "risk_score": 0.55, "landing_flag": True}},
379
+ {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
380
+ "step_count": 5,
381
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_landing_page", "submit_audit"],
382
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "request_landing_page", "submit_audit"),
383
+ "last_feedback": "audit_logged id=AUD-007",
384
+ "signals": {"policy_confidence": 0.42, "risk_score": 0.55, "landing_flag": True}},
385
+
386
+ # ---- TASK 8: Adversarial (fine print in image) ----
387
+ {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
388
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
389
+ "last_feedback": "Ad loaded for task_8_adversarial. Begin with query_regulations.",
390
+ "signals": {}},
391
+ {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
392
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
393
+ "setup_actions": _sa("query_regulations"),
394
+ "last_feedback": "policy_confidence=0.75",
395
+ "signals": {"policy_confidence": 0.75}},
396
+ {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
397
+ "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
398
+ "setup_actions": _sa("query_regulations", "analyze_image"),
399
+ "last_feedback": "image_violation_detected",
400
+ "signals": {"policy_confidence": 0.75, "image_flag": True}},
401
+ {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
402
+ "step_count": 4,
403
+ "actions_already_taken": ["query_regulations", "analyze_image", "submit_audit"],
404
+ "setup_actions": _sa("query_regulations", "analyze_image", "submit_audit"),
405
+ "last_feedback": "audit_logged id=AUD-008",
406
+ "signals": {"policy_confidence": 0.75, "image_flag": True}},
407
+
408
+ # ---- TASK 9: Dependency Trap (text clean, image has violation) ----
409
+ {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
410
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
411
+ "last_feedback": "Ad loaded for task_9_dependency_trap. Begin with query_regulations.",
412
+ "signals": {}},
413
+ {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
414
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
415
+ "setup_actions": _sa("query_regulations"),
416
+ "last_feedback": "policy_confidence=0.50",
417
+ "signals": {"policy_confidence": 0.50}},
418
+ {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
419
+ "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
420
+ "setup_actions": _sa("query_regulations", "analyze_image"),
421
+ "last_feedback": "image_violation_detected",
422
+ "signals": {"policy_confidence": 0.50, "image_flag": True}},
423
+ {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
424
+ "step_count": 4,
425
+ "actions_already_taken": ["query_regulations", "analyze_image", "submit_audit"],
426
+ "setup_actions": _sa("query_regulations", "analyze_image", "submit_audit"),
427
+ "last_feedback": "audit_logged id=AUD-009",
428
+ "signals": {"policy_confidence": 0.50, "image_flag": True}},
429
+
430
+ # ---- TASK 10: Failure Recovery ----
431
+ {"task_id": "task_10_failure", "ad_key": "task_10_failure",
432
+ "step_count": 1, "actions_already_taken": [], "setup_actions": [],
433
+ "last_feedback": "Ad loaded for task_10_failure. Begin with query_regulations.",
434
+ "signals": {}},
435
+ {"task_id": "task_10_failure", "ad_key": "task_10_failure",
436
+ "step_count": 2, "actions_already_taken": ["query_regulations"],
437
+ "setup_actions": _sa("query_regulations"),
438
+ "last_feedback": "policy_confidence=0.85",
439
+ "signals": {"policy_confidence": 0.85}},
440
+ {"task_id": "task_10_failure", "ad_key": "task_10_failure",
441
+ "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
442
+ "setup_actions": _sa("query_regulations", "check_advertiser_history"),
443
+ "last_feedback": "risk_score=0.80",
444
+ "signals": {"policy_confidence": 0.85, "risk_score": 0.80}},
445
+ {"task_id": "task_10_failure", "ad_key": "task_10_failure",
446
+ "step_count": 4,
447
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
448
+ "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
449
+ "last_feedback": "audit_logged id=AUD-010",
450
+ "signals": {"policy_confidence": 0.85, "risk_score": 0.80}},
451
+ ]
452
 
453
+ # =========================
454
+ # DATASET BUILDER
455
+ # =========================
456
 
457
def build_observation(scenario):
    """Assemble the observation payload for a single scenario.

    The layout is intended to mirror the observation JSON used by
    inference.py (per the original docstring; not verifiable from here).

    Args:
        scenario: Mapping with the keys ``ad_key``, ``task_id``,
            ``last_feedback``, ``step_count`` and ``actions_already_taken``,
            plus an optional ``signals`` mapping.

    Returns:
        A dict holding the top-level scenario fields and an ``ad_details``
        dict that merges the static ad data with per-step status fields.

    Raises:
        KeyError: If a required scenario key is missing or ``ad_key`` is not
            present in ``TASK_AD_DATA``.
    """
    ad_fields = TASK_AD_DATA[scenario["ad_key"]]
    signals = scenario.get("signals", {})

    # Top-level fields are copied straight from the scenario, in this order.
    observation = {
        key: scenario[key]
        for key in (
            "task_id",
            "last_feedback",
            "step_count",
            "actions_already_taken",
        )
    }

    # Start from the static ad data, then layer the per-step fields on top so
    # they override any same-named keys coming from the ad record.
    details = {**ad_fields}
    details["status_message"] = scenario["last_feedback"]
    details["reward"] = 0.0
    details["done"] = False

    # Signals that have not been gathered yet are emitted as None.
    for signal_name in (
        "risk_score",
        "policy_confidence",
        "image_flag",
        "landing_flag",
        "last_error",
    ):
        details[signal_name] = signals.get(signal_name)

    observation["ad_details"] = details
    return observation
478
 
479
 
480
  def build_dataset():
481
  rows = []
482
  for s in BASE_SCENARIOS:
483
+ obs = build_observation(s)
484
+ user_content = (
485
+ "Current Ad Observation:\n"
486
+ + json.dumps(obs, indent=2)
487
+ + "\n\nWhat is your next action?"
488
  )
489
  rows.append({
490
+ "prompt": [
491
+ {"role": "system", "content": SYSTEM_PROMPT},
492
+ {"role": "user", "content": user_content},
493
+ ],
494
  "task_id": s["task_id"],
495
  "setup_actions": s["setup_actions"],
496
  })
497
+ return Dataset.from_list(rows * 8)
498
 
499
  # =========================
500
+ # REWARD FUNCTION
501
  # =========================
502
 
 
 
503
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
504
+ """Shaped reward with phase-specific bonuses for meaningful GRPO gradients."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  client = EnvClient(ENV_URL)
506
  rewards = []
507
 
508
  if task_id is None or setup_actions is None:
 
 
 
509
  return [-1.0] * len(completions)
510
 
511
  for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
512
  parsed = extract_json(completion)
 
 
 
513
  if not parsed:
514
  rewards.append(-1.0)
515
  continue
 
525
  }
526
 
527
  try:
528
+ random.seed(hash((t_id, len(setup))) % (2**32 - 1))
529
  client.reset(t_id)
530
  for s in setup:
531
  safe_step(client, s)
 
545
  else:
546
  shaped = 0.5 + env_reward
547
 
548
+ taken = set(a["action_type"] for a in setup)
549
+
550
+ if not taken:
551
+ if action_type == "query_regulations":
552
+ shaped += 0.15
553
+ elif "submit_audit" in taken:
554
+ if action_type in ("approve", "reject"):
555
+ shaped += 0.2
556
+ else:
557
+ shaped -= 0.1
558
+ elif "query_regulations" in taken:
559
+ gathering = {
560
+ "analyze_image", "check_advertiser_history",
561
+ "request_landing_page", "request_id_verification",
562
+ }
563
+ if action_type in gathering:
564
+ shaped += 0.1
565
+ elif action_type == "submit_audit":
566
+ shaped += 0.1
567
+ elif action_type in ("approve", "reject"):
568
+ shaped -= 0.15
569
+
570
+ if t_id == "task_3_multimodal" and action_type == "analyze_image":
571
+ shaped += 0.1
572
+ if t_id == "task_4_targeting" and action_type == "request_id_verification":
573
+ shaped += 0.1
574
+ if t_id in ("task_8_adversarial", "task_9_dependency_trap") and action_type == "analyze_image":
575
+ shaped += 0.1
576
+
577
  rewards.append(shaped)
578
 
579
  except Exception:
 
589
  _props = torch.cuda.get_device_properties(0)
590
  _vram = _props.total_memory
591
  _name = _props.name
592
+ _cc = (_props.major, _props.minor)
593
  print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}")
594
  else:
595
  _vram = 0
596
  _name = "CPU"
597
  _cc = (0, 0)
598
 
599
+ USE_4BIT = _vram < 40 * 1024**3
600
+ USE_BF16 = _cc >= (8, 0) and not USE_4BIT
 
 
 
 
601
 
602
  model, tokenizer = FastLanguageModel.from_pretrained(
603
  model_name="unsloth/Llama-3.1-8B-Instruct",
 
626
 
627
  dataset = build_dataset()
628
 
 
 
 
 
629
  trainer = GRPOTrainer(
630
  model=model,
631
  reward_funcs=[reward_environment],
632
  args=GRPOConfig(
633
  output_dir="outputs",
634
+ learning_rate=5e-6,
635
+ num_train_epochs=1 if USE_4BIT else 2,
636
  per_device_train_batch_size=1 if USE_4BIT else 2,
637
+ gradient_accumulation_steps=4,
638
+ num_generations=4,
639
+ max_prompt_length=512,
640
+ max_completion_length=80,
641
+ logging_steps=5,
642
+ warmup_steps=10,
643
  bf16=USE_BF16,
644
  fp16=not USE_BF16,
645
  report_to="none",
 
655
  if __name__ == "__main__":
656
  ensure_env_ready()
657
 
 
 
 
658
  print("Starting GRPO training...")
659
  trainer.train()
660
 
 
662
  tokenizer.save_pretrained("outputs/lora_adapter")
663
  print("LoRA adapter saved to outputs/lora_adapter")
664
 
665
+ print("Merging adapter into base model...")
666
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
667
  model_name="outputs/lora_adapter",
668
  load_in_4bit=False,