parth-1 committed on
Commit
a5a9c5a
·
verified ·
1 Parent(s): f499944

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +149 -479
grpo_train.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import os
4
 
 
5
  os.environ.setdefault("USER", "user")
6
  os.environ.setdefault("HOME", "/tmp/home")
7
  os.environ.setdefault("HF_HOME", "/tmp/hf_home")
@@ -22,8 +23,6 @@ import time
22
  import json
23
  import random
24
  import requests
25
- from requests.adapters import HTTPAdapter
26
- from urllib3.util.retry import Retry
27
  import torch
28
 
29
  from datasets import Dataset
@@ -37,15 +36,19 @@ PatchFastRL("GRPO", FastLanguageModel)
37
  # =========================
38
 
39
  OUTPUT_DIR = "/tmp/outputs"
40
-
41
  ENV_URL = os.getenv("ENV_URL", "https://parth-1-metaguard.hf.space")
42
  HF_TOKEN = os.getenv("HF_TOKEN", "")
43
- HF_REPO = os.getenv("HF_REPO", "")
44
 
45
  ALLOWED_ACTIONS = [
46
- "query_regulations", "analyze_image", "check_advertiser_history",
47
- "request_landing_page", "request_id_verification",
48
- "submit_audit", "approve", "reject",
 
 
 
 
 
49
  ]
50
 
51
  # =========================
@@ -53,57 +56,54 @@ ALLOWED_ACTIONS = [
53
  # =========================
54
 
55
  def ensure_env_ready():
56
- for i in range(60):
57
  try:
58
  r = requests.post(
59
  f"{ENV_URL}/reset",
60
  json={"task_id": "task_1_healthcare"},
61
- timeout=15,
62
  )
63
  if r.status_code == 200:
64
- print(f"Environment ready (attempt {i+1})")
65
  return
66
- print(f"Env returned {r.status_code}, retrying...")
67
- except Exception as e:
68
- if i % 10 == 0:
69
- print(f"Env not ready (attempt {i+1}): {e}")
70
- time.sleep(2)
71
- raise RuntimeError("ENV not reachable after 60 attempts")
72
 
73
  # =========================
74
- # ENV CLIENT
75
  # =========================
76
 
77
- def _build_session():
78
- s = requests.Session()
79
- retries = Retry(total=3, backoff_factor=0.5,
80
- status_forcelist=[500, 502, 503, 504])
81
- s.mount("https://", HTTPAdapter(max_retries=retries))
82
- s.mount("http://", HTTPAdapter(max_retries=retries))
83
- return s
84
-
85
  class EnvClient:
86
  def __init__(self, url):
87
  self.url = url
88
- self.session = _build_session()
89
 
90
  def reset(self, task_id):
91
- return self.session.post(
92
- f"{self.url}/reset", json={"task_id": task_id}, timeout=15,
 
 
93
  ).json()
94
 
95
  def step(self, action):
96
- return self.session.post(
97
- f"{self.url}/step", json={"action": action}, timeout=15,
 
 
98
  ).json()
99
 
100
  def safe_step(client, action):
101
- for attempt in range(5):
102
  try:
103
  return client.step(action)
104
- except Exception:
105
- time.sleep(1.0 * (attempt + 1))
106
- return {"reward": -0.3, "status_message": "timeout"}
 
 
 
 
107
 
108
  def extract_json(text):
109
  try:
@@ -112,435 +112,146 @@ def extract_json(text):
112
  if text.startswith("json"):
113
  text = text[4:]
114
  return json.loads(text.strip())
115
- except Exception:
116
  return None
117
 
118
  # =========================
119
- # SYSTEM PROMPT (identical to inference.py)
120
  # =========================
121
 
122
- SYSTEM_PROMPT = (
123
- "You are an enterprise Ad Policy Compliance Agent.\n"
124
- "You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.\n"
125
- "\n"
126
- "REQUIRED PHASE ORDER:\n"
127
- "1. query_regulations \u2014 always first\n"
128
- "2. analyze_image \u2014 required for visual/multimodal tasks\n"
129
- "3. check_advertiser_history or request_landing_page \u2014 as needed\n"
130
- "4. submit_audit \u2014 always before final decision\n"
131
- "5. approve or reject \u2014 final decision only after audit\n"
132
- "\n"
133
- "AVAILABLE ACTIONS:\n"
134
- "- query_regulations\n"
135
- "- analyze_image\n"
136
- "- check_advertiser_history\n"
137
- "- request_landing_page\n"
138
- "- request_id_verification\n"
139
- "- submit_audit\n"
140
- "- approve\n"
141
- "- reject\n"
142
- "\n"
143
- "HARD RULES:\n"
144
- "- NEVER repeat an action listed in `actions_already_taken`.\n"
145
- "- You MUST progress through the phase order. Do NOT call submit_audit or approve/reject\n"
146
- " before the prerequisite phases are complete.\n"
147
- "- Choose your action_type ONLY from the AVAILABLE ACTIONS list above. Any other value is invalid.\n"
148
- "\n"
149
- 'Response format:\n'
150
- '{"action_type": "<action>", "reasoning": "<brief reason>"}'
151
- )
152
-
153
- # =========================
154
- # FIXED AD DATA (deterministic observations per task)
155
- # =========================
156
-
157
- TASK_AD_DATA = {
158
- "task_1_healthcare": {
159
- "ad_id": "hc-001", "headline": "Secret D1abet3s Cure",
160
- "body_text": "Reverse it today with this root extract!",
161
- "advertiser_trust_score": 0.25,
162
- "targeting_data": {"age_min": 25, "geo": "US"},
163
- "image_url": "https://mock-meta.com/img/hc-001.jpg",
164
- },
165
- "task_1_healthcare_b": {
166
- "ad_id": "hc-002", "headline": "Optimize Metabolic Markers",
167
- "body_text": "Our proprietary GLP-1 analog ships directly to your door without a visit.",
168
- "advertiser_trust_score": 0.35,
169
- "targeting_data": {"age_min": 30, "geo": "US"},
170
- "image_url": "https://mock-meta.com/img/hc-002.jpg",
171
- },
172
- "task_2_financial": {
173
- "ad_id": "fn-001", "headline": "Garanteed 500% Returns",
174
- "body_text": "Invest in LlamaCoin now. Zero risk.",
175
- "advertiser_trust_score": 0.20,
176
- "targeting_data": {"age_min": 21, "geo": "US"},
177
- "image_url": "https://mock-meta.com/img/fn-001.jpg",
178
  },
179
- "task_2_financial_b": {
180
- "ad_id": "fn-002", "headline": "Beat the System",
181
- "body_text": "Leverage our algorithmic trading bot to exploit market inefficiencies.",
182
- "advertiser_trust_score": 0.30,
183
- "targeting_data": {"age_min": 25, "geo": "US"},
184
- "image_url": "https://mock-meta.com/img/fn-002.jpg",
185
  },
186
- "task_3_multimodal": {
187
- "ad_id": "mm-001", "headline": "Achieve Your Dream Body",
188
- "body_text": "Hard work and our new program will get you there.",
189
- "advertiser_trust_score": 0.90,
190
- "targeting_data": {"age_min": 18, "geo": "US"},
191
- "image_url": "https://mock-meta.com/img/mm-001.jpg",
192
  },
193
- "task_3_multimodal_b": {
194
- "ad_id": "mm-002", "headline": "Stay Focused All Day",
195
- "body_text": "Our study supplement helps you crush your exams.",
196
- "advertiser_trust_score": 0.88,
197
- "targeting_data": {"age_min": 18, "geo": "US"},
198
- "image_url": "https://mock-meta.com/img/mm-002.jpg",
 
199
  },
200
- "task_4_targeting": {
201
- "ad_id": "tg-001", "headline": "Start Your First Investment Portfolio",
202
- "body_text": "Learn the basics of stock trading with our beginner-friendly app.",
203
- "advertiser_trust_score": 0.72,
204
- "targeting_data": {"age_min": 15, "geo": "US"},
205
- "image_url": "https://mock-meta.com/img/tg-001.jpg",
 
206
  },
207
- "task_6_conflict": {
208
- "ad_id": "cf-001", "headline": "Smart Investment Tools",
209
- "body_text": "Our platform helps you manage your portfolio with ease.",
210
- "advertiser_trust_score": 0.92,
211
- "targeting_data": {"age_min": 21, "geo": "US"},
212
- "image_url": "https://mock-meta.com/img/cf-001.jpg",
 
 
213
  },
214
- "task_7_ambiguous": {
215
- "ad_id": "am-001", "headline": "Feel Better Naturally",
216
- "body_text": "Our wellness program supports your health journey.",
217
- "advertiser_trust_score": 0.55,
218
- "targeting_data": {"age_min": 22, "geo": "US"},
219
- "image_url": "https://mock-meta.com/img/am-001.jpg",
 
 
 
220
  },
221
- "task_8_adversarial": {
222
- "ad_id": "av-001", "headline": "Natural Energy Boost",
223
- "body_text": "Feel amazing every day with our daily supplement.",
224
- "advertiser_trust_score": 0.40,
225
- "targeting_data": {"age_min": 20, "geo": "US"},
226
- "image_url": "https://mock-meta.com/img/av-001.jpg",
227
  },
228
- "task_9_dependency_trap": {
229
- "ad_id": "dt-001", "headline": "Wellness Support Pack",
230
- "body_text": "Certified vitamins to support your daily routine.",
231
- "advertiser_trust_score": 0.85,
232
- "targeting_data": {"age_min": 25, "geo": "US"},
233
- "image_url": "https://mock-meta.com/img/dt-001.jpg",
 
234
  },
235
- "task_10_failure": {
236
- "ad_id": "fr-001", "headline": "Miracle Weight Loss Patch",
237
- "body_text": "Lose 10kg in 2 weeks. Guaranteed results or money back.",
238
- "advertiser_trust_score": 0.15,
239
- "targeting_data": {"age_min": 22, "geo": "US"},
240
- "image_url": "https://mock-meta.com/img/fr-001.jpg",
 
 
 
241
  },
242
- }
243
 
244
- # =========================
245
- # SCENARIOS (43 unique across all tasks + phases)
246
- # =========================
247
 
248
- _sa = lambda *acts: [{"action_type": a, "reasoning": "setup"} for a in acts]
 
 
 
 
 
 
 
 
249
 
250
- BASE_SCENARIOS = [
251
- # ---- TASK 1: Healthcare (obvious violation) ----
252
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
253
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
254
- "last_feedback": "Ad loaded for task_1_healthcare. Begin with query_regulations.",
255
- "signals": {}},
256
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare_b",
257
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
258
- "last_feedback": "Ad loaded for task_1_healthcare. Begin with query_regulations.",
259
- "signals": {}},
260
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
261
- "step_count": 2, "actions_already_taken": ["query_regulations"],
262
- "setup_actions": _sa("query_regulations"),
263
- "last_feedback": "policy_confidence=0.92",
264
- "signals": {"policy_confidence": 0.92}},
265
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare_b",
266
- "step_count": 2, "actions_already_taken": ["query_regulations"],
267
- "setup_actions": _sa("query_regulations"),
268
- "last_feedback": "policy_confidence=0.78",
269
- "signals": {"policy_confidence": 0.78}},
270
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
271
- "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
272
- "setup_actions": _sa("query_regulations", "check_advertiser_history"),
273
- "last_feedback": "risk_score=0.82",
274
- "signals": {"policy_confidence": 0.92, "risk_score": 0.82}},
275
- {"task_id": "task_1_healthcare", "ad_key": "task_1_healthcare",
276
- "step_count": 4,
277
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
278
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
279
- "last_feedback": "audit_logged id=AUD-001",
280
- "signals": {"policy_confidence": 0.92, "risk_score": 0.82}},
281
-
282
- # ---- TASK 2: Financial (obvious violation) ----
283
- {"task_id": "task_2_financial", "ad_key": "task_2_financial",
284
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
285
- "last_feedback": "Ad loaded for task_2_financial. Begin with query_regulations.",
286
- "signals": {}},
287
- {"task_id": "task_2_financial", "ad_key": "task_2_financial_b",
288
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
289
- "last_feedback": "Ad loaded for task_2_financial. Begin with query_regulations.",
290
- "signals": {}},
291
- {"task_id": "task_2_financial", "ad_key": "task_2_financial",
292
- "step_count": 2, "actions_already_taken": ["query_regulations"],
293
- "setup_actions": _sa("query_regulations"),
294
- "last_feedback": "policy_confidence=0.88",
295
- "signals": {"policy_confidence": 0.88}},
296
- {"task_id": "task_2_financial", "ad_key": "task_2_financial",
297
- "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
298
- "setup_actions": _sa("query_regulations", "check_advertiser_history"),
299
- "last_feedback": "risk_score=0.75",
300
- "signals": {"policy_confidence": 0.88, "risk_score": 0.75}},
301
- {"task_id": "task_2_financial", "ad_key": "task_2_financial",
302
- "step_count": 4,
303
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
304
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
305
- "last_feedback": "audit_logged id=AUD-002",
306
- "signals": {"policy_confidence": 0.88, "risk_score": 0.75}},
307
-
308
- # ---- TASK 3: Multimodal (violation hidden in image) ----
309
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
310
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
311
- "last_feedback": "Ad loaded for task_3_multimodal. Begin with query_regulations.",
312
- "signals": {}},
313
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal_b",
314
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
315
- "last_feedback": "Ad loaded for task_3_multimodal. Begin with query_regulations.",
316
- "signals": {}},
317
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
318
- "step_count": 2, "actions_already_taken": ["query_regulations"],
319
- "setup_actions": _sa("query_regulations"),
320
- "last_feedback": "policy_confidence=0.65",
321
- "signals": {"policy_confidence": 0.65}},
322
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
323
- "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
324
- "setup_actions": _sa("query_regulations", "analyze_image"),
325
- "last_feedback": "image_violation_detected",
326
- "signals": {"policy_confidence": 0.65, "image_flag": True}},
327
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
328
- "step_count": 4,
329
- "actions_already_taken": ["query_regulations", "analyze_image", "check_advertiser_history"],
330
- "setup_actions": _sa("query_regulations", "analyze_image", "check_advertiser_history"),
331
- "last_feedback": "risk_score=0.45",
332
- "signals": {"policy_confidence": 0.65, "image_flag": True, "risk_score": 0.45}},
333
- {"task_id": "task_3_multimodal", "ad_key": "task_3_multimodal",
334
- "step_count": 5,
335
- "actions_already_taken": ["query_regulations", "analyze_image", "check_advertiser_history", "submit_audit"],
336
- "setup_actions": _sa("query_regulations", "analyze_image", "check_advertiser_history", "submit_audit"),
337
- "last_feedback": "audit_logged id=AUD-003",
338
- "signals": {"policy_confidence": 0.65, "image_flag": True, "risk_score": 0.45}},
339
-
340
- # ---- TASK 4: Targeting (minors) ----
341
- {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
342
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
343
- "last_feedback": "Ad loaded for task_4_targeting. Begin with query_regulations.",
344
- "signals": {}},
345
- {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
346
- "step_count": 2, "actions_already_taken": ["query_regulations"],
347
- "setup_actions": _sa("query_regulations"),
348
- "last_feedback": "policy_confidence=0.70",
349
- "signals": {"policy_confidence": 0.70}},
350
- {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
351
- "step_count": 3, "actions_already_taken": ["query_regulations", "request_id_verification"],
352
- "setup_actions": _sa("query_regulations", "request_id_verification"),
353
- "last_feedback": "ALERT: minor targeting age=15",
354
- "signals": {"policy_confidence": 0.70}},
355
- {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
356
- "step_count": 4,
357
- "actions_already_taken": ["query_regulations", "request_id_verification", "check_advertiser_history"],
358
- "setup_actions": _sa("query_regulations", "request_id_verification", "check_advertiser_history"),
359
- "last_feedback": "risk_score=0.60",
360
- "signals": {"policy_confidence": 0.70, "risk_score": 0.60}},
361
- {"task_id": "task_4_targeting", "ad_key": "task_4_targeting",
362
- "step_count": 5,
363
- "actions_already_taken": ["query_regulations", "request_id_verification", "check_advertiser_history", "submit_audit"],
364
- "setup_actions": _sa("query_regulations", "request_id_verification", "check_advertiser_history", "submit_audit"),
365
- "last_feedback": "audit_logged id=AUD-004",
366
- "signals": {"policy_confidence": 0.70, "risk_score": 0.60}},
367
-
368
- # ---- TASK 6: Conflict (high trust + risky) ----
369
- {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
370
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
371
- "last_feedback": "Ad loaded for task_6_conflict. Begin with query_regulations.",
372
- "signals": {}},
373
- {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
374
- "step_count": 2, "actions_already_taken": ["query_regulations"],
375
- "setup_actions": _sa("query_regulations"),
376
- "last_feedback": "policy_confidence=0.72",
377
- "signals": {"policy_confidence": 0.72}},
378
- {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
379
- "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
380
- "setup_actions": _sa("query_regulations", "check_advertiser_history"),
381
- "last_feedback": "risk_score=0.78",
382
- "signals": {"policy_confidence": 0.72, "risk_score": 0.78}},
383
- {"task_id": "task_6_conflict", "ad_key": "task_6_conflict",
384
- "step_count": 4,
385
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
386
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
387
- "last_feedback": "audit_logged id=AUD-006",
388
- "signals": {"policy_confidence": 0.72, "risk_score": 0.78}},
389
-
390
- # ---- TASK 7: Ambiguous (low confidence, need extra signals) ----
391
- {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
392
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
393
- "last_feedback": "Ad loaded for task_7_ambiguous. Begin with query_regulations.",
394
- "signals": {}},
395
- {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
396
- "step_count": 2, "actions_already_taken": ["query_regulations"],
397
- "setup_actions": _sa("query_regulations"),
398
- "last_feedback": "policy_confidence=0.42",
399
- "signals": {"policy_confidence": 0.42}},
400
- {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
401
- "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
402
- "setup_actions": _sa("query_regulations", "check_advertiser_history"),
403
- "last_feedback": "risk_score=0.55",
404
- "signals": {"policy_confidence": 0.42, "risk_score": 0.55}},
405
- {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
406
- "step_count": 4,
407
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_landing_page"],
408
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "request_landing_page"),
409
- "last_feedback": "landing_suspicious",
410
- "signals": {"policy_confidence": 0.42, "risk_score": 0.55, "landing_flag": True}},
411
- {"task_id": "task_7_ambiguous", "ad_key": "task_7_ambiguous",
412
- "step_count": 5,
413
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_landing_page", "submit_audit"],
414
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "request_landing_page", "submit_audit"),
415
- "last_feedback": "audit_logged id=AUD-007",
416
- "signals": {"policy_confidence": 0.42, "risk_score": 0.55, "landing_flag": True}},
417
-
418
- # ---- TASK 8: Adversarial (fine print in image) ----
419
- {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
420
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
421
- "last_feedback": "Ad loaded for task_8_adversarial. Begin with query_regulations.",
422
- "signals": {}},
423
- {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
424
- "step_count": 2, "actions_already_taken": ["query_regulations"],
425
- "setup_actions": _sa("query_regulations"),
426
- "last_feedback": "policy_confidence=0.75",
427
- "signals": {"policy_confidence": 0.75}},
428
- {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
429
- "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
430
- "setup_actions": _sa("query_regulations", "analyze_image"),
431
- "last_feedback": "image_violation_detected",
432
- "signals": {"policy_confidence": 0.75, "image_flag": True}},
433
- {"task_id": "task_8_adversarial", "ad_key": "task_8_adversarial",
434
- "step_count": 4,
435
- "actions_already_taken": ["query_regulations", "analyze_image", "submit_audit"],
436
- "setup_actions": _sa("query_regulations", "analyze_image", "submit_audit"),
437
- "last_feedback": "audit_logged id=AUD-008",
438
- "signals": {"policy_confidence": 0.75, "image_flag": True}},
439
-
440
- # ---- TASK 9: Dependency Trap (text clean, image has violation) ----
441
- {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
442
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
443
- "last_feedback": "Ad loaded for task_9_dependency_trap. Begin with query_regulations.",
444
- "signals": {}},
445
- {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
446
- "step_count": 2, "actions_already_taken": ["query_regulations"],
447
- "setup_actions": _sa("query_regulations"),
448
- "last_feedback": "policy_confidence=0.50",
449
- "signals": {"policy_confidence": 0.50}},
450
- {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
451
- "step_count": 3, "actions_already_taken": ["query_regulations", "analyze_image"],
452
- "setup_actions": _sa("query_regulations", "analyze_image"),
453
- "last_feedback": "image_violation_detected",
454
- "signals": {"policy_confidence": 0.50, "image_flag": True}},
455
- {"task_id": "task_9_dependency_trap", "ad_key": "task_9_dependency_trap",
456
- "step_count": 4,
457
- "actions_already_taken": ["query_regulations", "analyze_image", "submit_audit"],
458
- "setup_actions": _sa("query_regulations", "analyze_image", "submit_audit"),
459
- "last_feedback": "audit_logged id=AUD-009",
460
- "signals": {"policy_confidence": 0.50, "image_flag": True}},
461
-
462
- # ---- TASK 10: Failure Recovery ----
463
- {"task_id": "task_10_failure", "ad_key": "task_10_failure",
464
- "step_count": 1, "actions_already_taken": [], "setup_actions": [],
465
- "last_feedback": "Ad loaded for task_10_failure. Begin with query_regulations.",
466
- "signals": {}},
467
- {"task_id": "task_10_failure", "ad_key": "task_10_failure",
468
- "step_count": 2, "actions_already_taken": ["query_regulations"],
469
- "setup_actions": _sa("query_regulations"),
470
- "last_feedback": "policy_confidence=0.85",
471
- "signals": {"policy_confidence": 0.85}},
472
- {"task_id": "task_10_failure", "ad_key": "task_10_failure",
473
- "step_count": 3, "actions_already_taken": ["query_regulations", "check_advertiser_history"],
474
- "setup_actions": _sa("query_regulations", "check_advertiser_history"),
475
- "last_feedback": "risk_score=0.80",
476
- "signals": {"policy_confidence": 0.85, "risk_score": 0.80}},
477
- {"task_id": "task_10_failure", "ad_key": "task_10_failure",
478
- "step_count": 4,
479
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
480
- "setup_actions": _sa("query_regulations", "check_advertiser_history", "submit_audit"),
481
- "last_feedback": "audit_logged id=AUD-010",
482
- "signals": {"policy_confidence": 0.85, "risk_score": 0.80}},
483
- ]
484
 
485
- # =========================
486
- # DATASET BUILDER
487
- # =========================
 
 
 
488
 
489
- def build_observation(scenario):
490
- """Construct observation JSON matching inference.py format."""
491
- ad = TASK_AD_DATA[scenario["ad_key"]]
492
- sigs = scenario.get("signals", {})
493
- return {
494
- "task_id": scenario["task_id"],
495
- "last_feedback": scenario["last_feedback"],
496
- "step_count": scenario["step_count"],
497
- "actions_already_taken": scenario["actions_already_taken"],
498
- "ad_details": {
499
- **ad,
500
- "status_message": scenario["last_feedback"],
501
- "reward": 0.0,
502
- "done": False,
503
- "risk_score": sigs.get("risk_score"),
504
- "policy_confidence": sigs.get("policy_confidence"),
505
- "image_flag": sigs.get("image_flag"),
506
- "landing_flag": sigs.get("landing_flag"),
507
- "last_error": sigs.get("last_error"),
508
- },
509
- }
510
 
 
511
 
512
  def build_dataset():
513
  rows = []
514
  for s in BASE_SCENARIOS:
515
- obs = build_observation(s)
516
- user_content = (
517
- "Current Ad Observation:\n"
518
- + json.dumps(obs, indent=2)
519
- + "\n\nWhat is your next action?"
520
  )
521
  rows.append({
522
- "prompt": [
523
- {"role": "system", "content": SYSTEM_PROMPT},
524
- {"role": "user", "content": user_content},
525
- ],
526
  "task_id": s["task_id"],
527
  "setup_actions": s["setup_actions"],
528
  })
529
- return Dataset.from_list(rows * 8)
530
 
531
  # =========================
532
  # REWARD FUNCTION
533
  # =========================
534
 
535
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
536
- """Shaped reward with phase-specific bonuses for meaningful GRPO gradients."""
537
  client = EnvClient(ENV_URL)
538
  rewards = []
539
 
540
- if task_id is None or setup_actions is None:
541
- return [-1.0] * len(completions)
542
-
543
- for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
544
  parsed = extract_json(completion)
545
  if not parsed:
546
  rewards.append(-1.0)
@@ -557,7 +268,6 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
557
  }
558
 
559
  try:
560
- random.seed(hash((t_id, len(setup))) % (2**32 - 1))
561
  client.reset(t_id)
562
  for s in setup:
563
  safe_step(client, s)
@@ -577,35 +287,6 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
577
  else:
578
  shaped = 0.5 + env_reward
579
 
580
- taken = set(a["action_type"] for a in setup)
581
-
582
- if not taken:
583
- if action_type == "query_regulations":
584
- shaped += 0.15
585
- elif "submit_audit" in taken:
586
- if action_type in ("approve", "reject"):
587
- shaped += 0.2
588
- else:
589
- shaped -= 0.1
590
- elif "query_regulations" in taken:
591
- gathering = {
592
- "analyze_image", "check_advertiser_history",
593
- "request_landing_page", "request_id_verification",
594
- }
595
- if action_type in gathering:
596
- shaped += 0.1
597
- elif action_type == "submit_audit":
598
- shaped += 0.1
599
- elif action_type in ("approve", "reject"):
600
- shaped -= 0.15
601
-
602
- if t_id == "task_3_multimodal" and action_type == "analyze_image":
603
- shaped += 0.1
604
- if t_id == "task_4_targeting" and action_type == "request_id_verification":
605
- shaped += 0.1
606
- if t_id in ("task_8_adversarial", "task_9_dependency_trap") and action_type == "analyze_image":
607
- shaped += 0.1
608
-
609
  rewards.append(shaped)
610
 
611
  except Exception:
@@ -617,22 +298,11 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
617
  # MODEL
618
  # =========================
619
 
620
- if torch.cuda.is_available():
621
- _props = torch.cuda.get_device_properties(0)
622
- _vram = _props.total_memory
623
- _name = _props.name
624
- _cc = (_props.major, _props.minor)
625
- print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}")
626
- else:
627
- _vram = 0
628
- _name = "CPU"
629
- _cc = (0, 0)
630
-
631
  model, tokenizer = FastLanguageModel.from_pretrained(
632
- model_name="unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
633
- load_in_4bit=True,
634
  max_seq_length=2048,
635
- dtype=None,
636
  )
637
 
638
  model = FastLanguageModel.get_peft_model(
@@ -660,15 +330,15 @@ trainer = GRPOTrainer(
660
  reward_funcs=[reward_environment],
661
  args=GRPOConfig(
662
  output_dir=OUTPUT_DIR,
663
- learning_rate=5e-6,
664
  num_train_epochs=3,
665
- per_device_train_batch_size=4,
666
- gradient_accumulation_steps=4,
667
- num_generations=8,
668
- max_prompt_length=512,
669
- max_completion_length=80,
670
  logging_steps=5,
671
- warmup_steps=10,
672
  bf16=True,
673
  fp16=False,
674
  report_to="none",
@@ -691,17 +361,17 @@ if __name__ == "__main__":
691
  try:
692
  trainer.train()
693
  except torch.cuda.OutOfMemoryError:
694
- print("OOM detected! Clearing cache and retrying with batch_size=2, num_generations=4...")
695
  torch.cuda.empty_cache()
696
- trainer.args.per_device_train_batch_size = 2
697
- trainer.args.num_generations = 4
698
  trainer.train()
699
 
700
  model.save_pretrained(LORA_DIR)
701
  tokenizer.save_pretrained(LORA_DIR)
702
  print(f"LoRA adapter saved to {LORA_DIR}")
703
 
704
- print("Merging adapter into base model...")
705
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
706
  model_name=LORA_DIR,
707
  load_in_4bit=False,
 
2
 
3
  import os
4
 
5
+ # Route all caches to /tmp/ to avoid Hugging Face Spaces Read-Only errors
6
  os.environ.setdefault("USER", "user")
7
  os.environ.setdefault("HOME", "/tmp/home")
8
  os.environ.setdefault("HF_HOME", "/tmp/hf_home")
 
23
  import json
24
  import random
25
  import requests
 
 
26
  import torch
27
 
28
  from datasets import Dataset
 
36
  # =========================
37
 
38
  OUTPUT_DIR = "/tmp/outputs"
 
39
  ENV_URL = os.getenv("ENV_URL", "https://parth-1-metaguard.hf.space")
40
  HF_TOKEN = os.getenv("HF_TOKEN", "")
41
+ HF_REPO = os.getenv("HF_REPO", "") # e.g. "parth-1/metaguard-llama3.1-8b-grpo"
42
 
43
  ALLOWED_ACTIONS = [
44
+ "query_regulations",
45
+ "analyze_image",
46
+ "check_advertiser_history",
47
+ "request_landing_page",
48
+ "request_id_verification",
49
+ "submit_audit",
50
+ "approve",
51
+ "reject",
52
  ]
53
 
54
  # =========================
 
56
  # =========================
57
 
58
  def ensure_env_ready():
59
+ for _ in range(20):
60
  try:
61
  r = requests.post(
62
  f"{ENV_URL}/reset",
63
  json={"task_id": "task_1_healthcare"},
64
+ timeout=5
65
  )
66
  if r.status_code == 200:
67
+ print("Environment ready")
68
  return
69
+ except:
70
+ pass
71
+ time.sleep(1)
72
+ raise RuntimeError(" ENV not reachable")
 
 
73
 
74
  # =========================
75
+ # SAFE CLIENT
76
  # =========================
77
 
 
 
 
 
 
 
 
 
78
  class EnvClient:
79
  def __init__(self, url):
80
  self.url = url
 
81
 
82
  def reset(self, task_id):
83
+ return requests.post(
84
+ f"{self.url}/reset",
85
+ json={"task_id": task_id},
86
+ timeout=8
87
  ).json()
88
 
89
  def step(self, action):
90
+ return requests.post(
91
+ f"{self.url}/step",
92
+ json={"action": action},
93
+ timeout=8
94
  ).json()
95
 
96
  def safe_step(client, action):
97
+ for _ in range(3):
98
  try:
99
  return client.step(action)
100
+ except:
101
+ time.sleep(0.5)
102
+ return {"reward": -0.3}
103
+
104
+ # =========================
105
+ # JSON PARSER
106
+ # =========================
107
 
108
  def extract_json(text):
109
  try:
 
112
  if text.startswith("json"):
113
  text = text[4:]
114
  return json.loads(text.strip())
115
+ except:
116
  return None
117
 
118
  # =========================
119
+ # DATASET
120
  # =========================
121
 
122
+ BASE_SCENARIOS = [
123
+ {
124
+ "task_id": "task_1_healthcare",
125
+ "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
126
+ "actions_already_taken": [],
127
+ "setup_actions": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  },
129
+ {
130
+ "task_id": "task_2_financial",
131
+ "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
132
+ "actions_already_taken": [],
133
+ "setup_actions": [],
 
134
  },
135
+ {
136
+ "task_id": "task_3_multimodal",
137
+ "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
138
+ "actions_already_taken": [],
139
+ "setup_actions": [],
 
140
  },
141
+ {
142
+ "task_id": "task_1_healthcare",
143
+ "text": "Healthcare ad: pharma product. Policy already queried.",
144
+ "actions_already_taken": ["query_regulations"],
145
+ "setup_actions": [
146
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
147
+ ],
148
  },
149
+ {
150
+ "task_id": "task_3_multimodal",
151
+ "text": "Multimodal ad: image not yet inspected. Policy already queried.",
152
+ "actions_already_taken": ["query_regulations"],
153
+ "setup_actions": [
154
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
155
+ ],
156
  },
157
+ {
158
+ "task_id": "task_2_financial",
159
+ "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
160
+ "actions_already_taken": ["query_regulations", "check_advertiser_history"],
161
+ "setup_actions": [
162
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
163
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
164
+ ],
165
  },
166
+ {
167
+ "task_id": "task_2_financial",
168
+ "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
169
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
170
+ "setup_actions": [
171
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
172
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
173
+ {"action_type": "submit_audit", "reasoning": "audit log"},
174
+ ],
175
  },
176
+ {
177
+ "task_id": "task_4_targeting",
178
+ "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
179
+ "actions_already_taken": [],
180
+ "setup_actions": [],
 
181
  },
182
+ {
183
+ "task_id": "task_4_targeting",
184
+ "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
185
+ "actions_already_taken": ["query_regulations"],
186
+ "setup_actions": [
187
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
188
+ ],
189
  },
190
+ {
191
+ "task_id": "task_4_targeting",
192
+ "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
193
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
194
+ "setup_actions": [
195
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
196
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
197
+ {"action_type": "request_id_verification", "reasoning": "age check"},
198
+ ],
199
  },
200
+ ]
201
 
202
+ PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
 
 
203
 
204
+ You MUST choose exactly ONE action_type from this list (any other value is invalid):
205
+ - query_regulations
206
+ - analyze_image
207
+ - check_advertiser_history
208
+ - request_landing_page
209
+ - request_id_verification
210
+ - submit_audit
211
+ - approve
212
+ - reject
213
 
214
+ REQUIRED PHASE ORDER:
215
+ 1. query_regulations -> always first
216
+ 2. analyze_image / check_advertiser_history -> gather signals
217
+ 3. submit_audit -> always before final decision
218
+ 4. approve OR reject -> only after audit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
+ HARD RULES:
221
+ - NEVER repeat an action listed in `actions_already_taken`.
222
+ - Respond with ONLY a valid JSON object. No markdown, no prose.
223
+
224
+ Required format:
225
+ {{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}
226
 
227
+ Scenario: {text}
228
+ actions_already_taken: {actions_already_taken}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ Your next action?"""
231
 
232
  def build_dataset():
233
  rows = []
234
  for s in BASE_SCENARIOS:
235
+ prompt = PROMPT_TEMPLATE.format(
236
+ text=s["text"],
237
+ actions_already_taken=json.dumps(s["actions_already_taken"]),
 
 
238
  )
239
  rows.append({
240
+ "prompt": prompt,
 
 
 
241
  "task_id": s["task_id"],
242
  "setup_actions": s["setup_actions"],
243
  })
244
+ return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
245
 
246
  # =========================
247
  # REWARD FUNCTION
248
  # =========================
249
 
250
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
 
251
  client = EnvClient(ENV_URL)
252
  rewards = []
253
 
254
+ for completion, t_id, setup in zip(completions, task_id, setup_actions):
 
 
 
255
  parsed = extract_json(completion)
256
  if not parsed:
257
  rewards.append(-1.0)
 
268
  }
269
 
270
  try:
 
271
  client.reset(t_id)
272
  for s in setup:
273
  safe_step(client, s)
 
287
  else:
288
  shaped = 0.5 + env_reward
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  rewards.append(shaped)
291
 
292
  except Exception:
 
298
  # MODEL
299
  # =========================
300
 
 
 
 
 
 
 
 
 
 
 
 
301
  model, tokenizer = FastLanguageModel.from_pretrained(
302
+ model_name="unsloth/Llama-3.1-8B-Instruct",
303
+ load_in_4bit=True, # Strictly True for L4 24GB
304
  max_seq_length=2048,
305
+ dtype=torch.bfloat16, # L4 Native Support
306
  )
307
 
308
  model = FastLanguageModel.get_peft_model(
 
330
  reward_funcs=[reward_environment],
331
  args=GRPOConfig(
332
  output_dir=OUTPUT_DIR,
333
+ learning_rate=2e-5,
334
  num_train_epochs=3,
335
+ per_device_train_batch_size=1, # Memory safe for L4
336
+ gradient_accumulation_steps=8, # Maintain effective batch size of 8
337
+ num_generations=2, # Memory safe generation limit
338
+ max_prompt_length=768,
339
+ max_completion_length=128,
340
  logging_steps=5,
341
+ warmup_ratio=0.1,
342
  bf16=True,
343
  fp16=False,
344
  report_to="none",
 
361
  try:
362
  trainer.train()
363
  except torch.cuda.OutOfMemoryError:
364
+ print("OOM detected! Clearing cache and severely restricting memory...")
365
  torch.cuda.empty_cache()
366
+ trainer.args.per_device_train_batch_size = 1
367
+ trainer.args.gradient_accumulation_steps = 16
368
  trainer.train()
369
 
370
  model.save_pretrained(LORA_DIR)
371
  tokenizer.save_pretrained(LORA_DIR)
372
  print(f"LoRA adapter saved to {LORA_DIR}")
373
 
374
+ print("Merging adapter into base model (bf16)...")
375
  merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
376
  model_name=LORA_DIR,
377
  load_in_4bit=False,