parth-1 committed on
Commit
aea0b8c
·
verified ·
1 Parent(s): 2961503

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +99 -350
grpo_train.py CHANGED
@@ -1,403 +1,152 @@
1
- # grpo_train.py
2
-
3
- import os
4
-
5
- # Route all caches to /tmp/ to avoid Hugging Face Spaces Read-Only errors
6
- os.environ.setdefault("USER", "user")
7
- os.environ.setdefault("HOME", "/tmp/home")
8
- os.environ.setdefault("HF_HOME", "/tmp/hf_home")
9
- os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_home/transformers")
10
- os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton_cache")
11
- os.environ.setdefault("TORCH_EXTENSIONS_DIR", "/tmp/torch_ext")
12
- os.environ.setdefault("XDG_CACHE_HOME", "/tmp/cache")
13
- os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
14
-
15
- for _d in [
16
- "/tmp/home", "/tmp/hf_home", "/tmp/hf_home/transformers",
17
- "/tmp/triton_cache", "/tmp/torch_ext", "/tmp/cache",
18
- "/tmp/mpl", "/tmp/outputs",
19
- ]:
20
- os.makedirs(_d, exist_ok=True)
21
-
22
- import time
23
  import json
24
- import random
25
- import requests
26
  import torch
27
-
28
  from datasets import Dataset
29
  from unsloth import FastLanguageModel, PatchFastRL
30
  from trl import GRPOTrainer, GRPOConfig
31
 
 
32
  PatchFastRL("GRPO", FastLanguageModel)
33
 
34
- # =========================
35
- # CONFIG
36
- # =========================
37
-
38
- OUTPUT_DIR = "/tmp/outputs"
39
- ENV_URL = os.getenv("ENV_URL", "https://parth-1-metaguard.hf.space")
40
- HF_TOKEN = os.getenv("HF_TOKEN", "")
41
- HF_REPO = os.getenv("HF_REPO", "") # e.g. "parth-1/metaguard-llama3.1-8b-grpo"
42
-
43
- ALLOWED_ACTIONS = [
44
- "query_regulations",
45
- "analyze_image",
46
- "check_advertiser_history",
47
- "request_landing_page",
48
- "request_id_verification",
49
- "submit_audit",
50
- "approve",
51
- "reject",
52
- ]
53
-
54
- # =========================
55
- # HEALTH CHECK
56
- # =========================
57
-
58
- def ensure_env_ready():
59
- for _ in range(20):
60
- try:
61
- r = requests.post(
62
- f"{ENV_URL}/reset",
63
- json={"task_id": "task_1_healthcare"},
64
- timeout=5
65
- )
66
- if r.status_code == 200:
67
- print("✅ Environment ready")
68
- return
69
- except:
70
- pass
71
- time.sleep(1)
72
- raise RuntimeError("❌ ENV not reachable")
73
-
74
- # =========================
75
- # SAFE CLIENT
76
- # =========================
77
-
78
- class EnvClient:
79
- def __init__(self, url):
80
- self.url = url
81
-
82
- def reset(self, task_id):
83
- return requests.post(
84
- f"{self.url}/reset",
85
- json={"task_id": task_id},
86
- timeout=8
87
- ).json()
88
-
89
- def step(self, action):
90
- return requests.post(
91
- f"{self.url}/step",
92
- json={"action": action},
93
- timeout=8
94
- ).json()
95
-
96
- def safe_step(client, action):
97
- for _ in range(3):
98
- try:
99
- return client.step(action)
100
- except:
101
- time.sleep(0.5)
102
- return {"reward": -0.3}
103
 
104
- # =========================
105
- # JSON PARSER
106
- # =========================
107
-
108
- def extract_json(text):
109
- try:
110
- if "```" in text:
111
- text = text.split("```")[1]
112
- if text.startswith("json"):
113
- text = text[4:]
114
- return json.loads(text.strip())
115
- except:
116
- return None
117
-
118
- # =========================
119
- # DATASET
120
- # =========================
121
-
122
- BASE_SCENARIOS = [
123
- {
124
- "task_id": "task_1_healthcare",
125
- "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
126
- "actions_already_taken": [],
127
- "setup_actions": [],
128
- },
129
- {
130
- "task_id": "task_2_financial",
131
- "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
132
- "actions_already_taken": [],
133
- "setup_actions": [],
134
- },
135
- {
136
- "task_id": "task_3_multimodal",
137
- "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
138
- "actions_already_taken": [],
139
- "setup_actions": [],
140
- },
141
- {
142
- "task_id": "task_1_healthcare",
143
- "text": "Healthcare ad: pharma product. Policy already queried.",
144
- "actions_already_taken": ["query_regulations"],
145
- "setup_actions": [
146
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
147
- ],
148
- },
149
- {
150
- "task_id": "task_3_multimodal",
151
- "text": "Multimodal ad: image not yet inspected. Policy already queried.",
152
- "actions_already_taken": ["query_regulations"],
153
- "setup_actions": [
154
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
155
- ],
156
- },
157
- {
158
- "task_id": "task_2_financial",
159
- "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
160
- "actions_already_taken": ["query_regulations", "check_advertiser_history"],
161
- "setup_actions": [
162
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
163
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
164
- ],
165
- },
166
- {
167
- "task_id": "task_2_financial",
168
- "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
169
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
170
- "setup_actions": [
171
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
172
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
173
- {"action_type": "submit_audit", "reasoning": "audit log"},
174
- ],
175
- },
176
- {
177
- "task_id": "task_4_targeting",
178
- "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
179
- "actions_already_taken": [],
180
- "setup_actions": [],
181
- },
182
- {
183
- "task_id": "task_4_targeting",
184
- "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
185
- "actions_already_taken": ["query_regulations"],
186
- "setup_actions": [
187
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
188
- ],
189
- },
190
- {
191
- "task_id": "task_4_targeting",
192
- "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
193
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
194
- "setup_actions": [
195
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
196
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
197
- {"action_type": "request_id_verification", "reasoning": "age check"},
198
- ],
199
- },
200
- ]
201
-
202
- PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
203
-
204
- You MUST choose exactly ONE action_type from this list (any other value is invalid):
205
- - query_regulations
206
- - analyze_image
207
- - check_advertiser_history
208
- - request_landing_page
209
- - request_id_verification
210
- - submit_audit
211
- - approve
212
- - reject
213
 
214
  REQUIRED PHASE ORDER:
215
- 1. query_regulations -> always first
216
- 2. analyze_image / check_advertiser_history -> gather signals
217
- 3. submit_audit -> always before final decision
218
- 4. approve OR reject -> only after audit
219
-
220
- HARD RULES:
221
- - NEVER repeat an action listed in `actions_already_taken`.
222
- - Respond with ONLY a valid JSON object. No markdown, no prose.
223
 
224
- Required format:
225
- {{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}
226
 
227
- Scenario: {text}
228
- actions_already_taken: {actions_already_taken}
229
-
230
- Your next action?"""
231
 
232
  def build_dataset():
233
  rows = []
234
- for s in BASE_SCENARIOS:
235
- prompt = PROMPT_TEMPLATE.format(
236
- text=s["text"],
237
- actions_already_taken=json.dumps(s["actions_already_taken"]),
 
 
 
 
 
 
 
 
 
238
  )
239
- rows.append({
240
- "prompt": prompt,
241
- "task_id": s["task_id"],
242
- "setup_actions": s["setup_actions"],
243
- })
244
- return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
245
 
246
- # =========================
247
- # REWARD FUNCTION
248
- # =========================
249
 
250
- def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
251
- client = EnvClient(ENV_URL)
 
 
 
252
  rewards = []
253
-
254
- for completion, t_id, setup in zip(completions, task_id, setup_actions):
255
- parsed = extract_json(completion)
256
- if not parsed:
257
- rewards.append(-1.0)
258
- continue
259
-
260
- action_type = parsed.get("action_type")
261
- if action_type not in ALLOWED_ACTIONS:
262
- rewards.append(-1.0)
 
 
 
 
263
  continue
264
 
265
- action = {
266
- "action_type": action_type,
267
- "reasoning": parsed.get("reasoning", "format-compliant"),
268
- }
269
-
270
  try:
271
- client.reset(t_id)
272
- for s in setup:
273
- safe_step(client, s)
274
-
275
- result = safe_step(client, action)
276
- env_reward = float(result.get("reward", -0.2))
277
- status_msg = (result.get("status_message") or "").lower()
278
-
279
- rejected = (
280
- "api failure" in status_msg
281
- or "invalid action" in status_msg
282
- or "must call" in status_msg
283
  )
 
 
 
 
284
 
285
- if rejected:
286
- shaped = -0.5
287
- else:
288
- shaped = 0.5 + env_reward
289
-
290
- rewards.append(shaped)
291
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  except Exception:
293
- rewards.append(-0.3)
294
-
295
  return rewards
296
 
297
- # =========================
298
- # MODEL
299
- # =========================
300
 
301
  model, tokenizer = FastLanguageModel.from_pretrained(
302
  model_name="unsloth/Llama-3.1-8B-Instruct",
303
- load_in_4bit=True, # Strictly True for L4 24GB
304
- max_seq_length=2048,
305
- dtype=torch.float16, # PERFECT ALIGNMENT: 4-bit uses fp16 math natively
306
  )
307
-
308
  model = FastLanguageModel.get_peft_model(
309
  model,
310
- r=32,
311
- target_modules=[
312
- "q_proj", "k_proj", "v_proj", "o_proj",
313
- "gate_proj", "up_proj", "down_proj",
314
- ],
315
- lora_alpha=64,
316
- lora_dropout=0,
317
- bias="none",
318
  use_gradient_checkpointing="unsloth",
319
- random_state=3407,
320
  )
321
 
322
- # =========================
323
- # TRAINER
324
- # =========================
325
 
326
  dataset = build_dataset()
327
 
328
  trainer = GRPOTrainer(
329
  model=model,
330
- reward_funcs=[reward_environment],
331
  args=GRPOConfig(
332
- output_dir=OUTPUT_DIR,
333
- learning_rate=2e-5,
334
- num_train_epochs=3,
335
- per_device_train_batch_size=1, # Memory safe for L4
336
- gradient_accumulation_steps=8, # Maintain effective batch size of 8
337
- num_generations=2, # Memory safe generation limit
338
- max_prompt_length=768,
339
  max_completion_length=128,
 
340
  logging_steps=5,
341
- warmup_ratio=0.1,
342
- bf16=False, # DISABLED TO PREVENT CLASH
343
- fp16=True, # ENABLED TO MATCH MODEL DTYPE
344
  report_to="none",
345
  ),
346
  train_dataset=dataset,
347
  tokenizer=tokenizer,
348
  )
349
 
350
- # =========================
351
- # RUN
352
- # =========================
353
-
354
  if __name__ == "__main__":
355
- ensure_env_ready()
356
-
357
- LORA_DIR = os.path.join(OUTPUT_DIR, "lora_adapter")
358
- MERGED_DIR = os.path.join(OUTPUT_DIR, "merged")
359
-
360
- print("Starting GRPO training...")
361
- try:
362
- trainer.train()
363
- except torch.cuda.OutOfMemoryError:
364
- print("OOM detected! Clearing cache and severely restricting memory...")
365
- torch.cuda.empty_cache()
366
- trainer.args.per_device_train_batch_size = 1
367
- trainer.args.gradient_accumulation_steps = 16
368
- trainer.train()
369
-
370
- model.save_pretrained(LORA_DIR)
371
- tokenizer.save_pretrained(LORA_DIR)
372
- print(f"LoRA adapter saved to {LORA_DIR}")
373
-
374
- print("Merging adapter into base model (fp16)...")
375
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
376
- model_name=LORA_DIR,
377
- load_in_4bit=False,
378
- max_seq_length=2048,
379
- )
380
- merged_model.save_pretrained_merged(
381
- MERGED_DIR,
382
- merged_tokenizer,
383
- save_method="merged_16bit",
384
- )
385
- print(f"Merged model saved to {MERGED_DIR}")
386
-
387
- if HF_REPO:
388
- try:
389
- print(f"Pushing merged model to {HF_REPO}...")
390
- merged_model.push_to_hub_merged(
391
- HF_REPO,
392
- merged_tokenizer,
393
- save_method="merged_16bit",
394
- token=HF_TOKEN,
395
- )
396
- print(f"Model live at https://huggingface.co/{HF_REPO}")
397
- except Exception as e:
398
- print(f"Hub push failed: {e}")
399
- print(f"Model is still saved locally at {MERGED_DIR}")
400
- else:
401
- print("Set HF_REPO env var to auto-push to Hub (skipped).")
402
-
403
- print("Done.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
 
 
2
  import torch
3
+ import requests
4
  from datasets import Dataset
5
  from unsloth import FastLanguageModel, PatchFastRL
6
  from trl import GRPOTrainer, GRPOConfig
7
 
8
+ # MUST be called before trainer instantiation
9
  PatchFastRL("GRPO", FastLanguageModel)
10
 
11
+ ENV_URL = "http://localhost:8000"
12
+ TASKS = ["task_1_healthcare", "task_2_financial",
13
+ "task_3_multimodal", "task_4_targeting"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
16
+ Always respond with ONLY valid JSON, no markdown.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  REQUIRED PHASE ORDER:
19
+ 1. query_regulations — always first
19
+ 2. analyze_image — required for multimodal tasks
20
+ 3. submit_audit — always before final decision
21
+ 4. approve or reject — only after audit
 
 
 
 
23
 
24
+ Format: {"action_type": "<action>", "reasoning": "<reason>"}"""
 
25
 
26
+ # ── DATASET ───────────────────────────────────────────────────────────────────
 
 
 
27
 
28
  def build_dataset():
29
  rows = []
30
+ for task_id in TASKS:
31
+ res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
32
+ obs = res.json()
33
+ prompt = (
34
+ f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
35
+ f"{SYSTEM_PROMPT}<|eot_id|>"
36
+ f"<|start_header_id|>user<|end_header_id|>\n"
37
+ f"Task: {task_id}\n"
38
+ f"Ad: {obs.get('headline','N/A')} — {obs.get('body_text','N/A')}\n"
39
+ f"Trust Score: {obs.get('advertiser_trust_score','N/A')}\n"
40
+ f"Status: {obs.get('status_message','')}\n"
41
+ f"What is your next action?"
42
+ f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
43
  )
44
+ rows.append({"prompt": prompt, "task_id": task_id})
45
+ # 25x repetition = 100 rows, enough for 1 epoch
46
+ return Dataset.from_list(rows * 25)
 
 
 
47
 
48
+ # ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
 
 
49
 
50
+ def reward_environment(prompts, completions, task_id, **kwargs):
51
+ """
52
+ This is the real reward — model outputs an action,
53
+ we send it to the environment, environment returns the reward.
54
+ """
55
  rewards = []
56
+ # Notice we zip with task_id (from the dataset) and use t_id inside the loop
57
+ for completion, t_id in zip(completions, task_id):
58
+ try:
59
+ # Parse model output
60
+ content = completion.strip()
61
+ if content.startswith("```"):
62
+ content = content.split("```")[1]
63
+ if content.startswith("json"):
64
+ content = content[4:]
65
+ action = json.loads(content.strip())
66
+ action_type = action.get("action_type", "query_regulations")
67
+ except Exception:
68
+ # Malformed JSON = penalty
69
+ rewards.append(-0.5)
70
  continue
71
 
 
 
 
 
 
72
  try:
73
+ # Fresh episode for each reward calculation
74
+ requests.post(f"{ENV_URL}/reset", json={"task_id": t_id})
75
+
76
+ # Run a minimal sequence: if model says query_regulations,
77
+ # run that then check what reward it generates
78
+ step_res = requests.post(
79
+ f"{ENV_URL}/step",
80
+ json={"action": {"action_type": action_type,
81
+ "reasoning": action.get("reasoning", "")}},
82
+ timeout=5
 
 
83
  )
84
+ data = step_res.json()
85
+ rewards.append(float(data.get("reward", -0.1)))
86
+ except Exception:
87
+ rewards.append(-0.1)
88
 
89
+ return rewards
 
 
 
 
 
90
 
91
+ def reward_json_format(prompts, completions, **kwargs):
92
+ """Bonus reward for valid JSON output."""
93
+ rewards = []
94
+ for completion in completions:
95
+ try:
96
+ content = completion.strip()
97
+ if content.startswith("```"):
98
+ content = content.split("```")[1]
99
+ if content.startswith("json"):
100
+ content = content[4:]
101
+ json.loads(content.strip())
102
+ rewards.append(0.5)
103
  except Exception:
104
+ rewards.append(-0.5)
 
105
  return rewards
106
 
107
+ # ── MODEL SETUP ───────────────────────────────────────────────────────────────
 
 
108
 
109
  model, tokenizer = FastLanguageModel.from_pretrained(
110
  model_name="unsloth/Llama-3.1-8B-Instruct",
111
+ max_seq_length=1024,
112
+ load_in_4bit=True,
 
113
  )
 
114
  model = FastLanguageModel.get_peft_model(
115
  model,
116
+ r=16,
117
+ target_modules=["q_proj", "v_proj"],
118
+ lora_alpha=16,
119
+ lora_dropout=0.0,
 
 
 
 
120
  use_gradient_checkpointing="unsloth",
 
121
  )
122
 
123
+ # ── TRAINER ───────────────────────────────────────────────────────────────────
 
 
124
 
125
  dataset = build_dataset()
126
 
127
  trainer = GRPOTrainer(
128
  model=model,
129
+ reward_funcs=[reward_environment, reward_json_format],
130
  args=GRPOConfig(
131
+ output_dir="outputs/meta-ad-agent",
132
+ learning_rate=5e-6,
133
+ num_train_epochs=1,
134
+ per_device_train_batch_size=2,
135
+ gradient_accumulation_steps=4,
136
+ max_prompt_length=512,
 
137
  max_completion_length=128,
138
+ num_generations=4, # lower = faster, enough for demo
139
  logging_steps=5,
140
+ save_steps=50,
 
 
141
  report_to="none",
142
  ),
143
  train_dataset=dataset,
144
  tokenizer=tokenizer,
145
  )
146
 
 
 
 
 
147
  if __name__ == "__main__":
148
+ print("Starting GRPO training — environment must be running on :8000")
149
+ trainer.train()
150
+ model.save_pretrained("outputs/meta-ad-agent-final")
151
+ tokenizer.save_pretrained("outputs/meta-ad-agent-final")
152
+ print("Done. Model saved to outputs/meta-ad-agent-final")