Kartik Goyal commited on
Commit
b9daf1b
·
1 Parent(s): 47fa380

logic update 2.0

Browse files
Files changed (1) hide show
  1. grpo_train.py +85 -21
grpo_train.py CHANGED
@@ -11,7 +11,6 @@ from datasets import Dataset
11
  from unsloth import FastLanguageModel, PatchFastRL
12
  from trl import GRPOTrainer, GRPOConfig
13
 
14
- # 🔥 MUST come before trainer
15
  PatchFastRL("GRPO", FastLanguageModel)
16
 
17
  # =========================
@@ -19,14 +18,18 @@ PatchFastRL("GRPO", FastLanguageModel)
19
  # =========================
20
 
21
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
 
 
22
 
23
  ALLOWED_ACTIONS = [
24
  "query_regulations",
25
  "analyze_image",
26
  "check_advertiser_history",
 
 
27
  "submit_audit",
28
  "approve",
29
- "reject"
30
  ]
31
 
32
  # =========================
@@ -158,6 +161,34 @@ BASE_SCENARIOS = [
158
  {"action_type": "submit_audit", "reasoning": "audit log"},
159
  ],
160
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  ]
162
 
163
  PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
@@ -166,6 +197,8 @@ You MUST choose exactly ONE action_type from this list (any other value is inval
166
  - query_regulations
167
  - analyze_image
168
  - check_advertiser_history
 
 
169
  - submit_audit
170
  - approve
171
  - reject
@@ -201,7 +234,7 @@ def build_dataset():
201
  "task_id": s["task_id"],
202
  "setup_actions": s["setup_actions"],
203
  })
204
- return Dataset.from_list(rows * 10) # 7 scenarios x 10 = 70 examples
205
 
206
  # =========================
207
  # REWARD FUNCTION (FIXED)
@@ -267,20 +300,23 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
267
  # MODEL
268
  # =========================
269
 
 
 
270
  model, tokenizer = FastLanguageModel.from_pretrained(
271
  model_name="unsloth/Llama-3.1-8B-Instruct",
272
- load_in_4bit=True,
273
- max_seq_length=1024,
 
274
  )
275
 
276
  model = FastLanguageModel.get_peft_model(
277
  model,
278
- r=16,
279
  target_modules=[
280
  "q_proj", "k_proj", "v_proj", "o_proj",
281
  "gate_proj", "up_proj", "down_proj",
282
  ],
283
- lora_alpha=32,
284
  lora_dropout=0,
285
  bias="none",
286
  use_gradient_checkpointing="unsloth",
@@ -298,18 +334,20 @@ trainer = GRPOTrainer(
298
  reward_funcs=[reward_environment],
299
  args=GRPOConfig(
300
  output_dir="outputs",
301
- learning_rate=5e-6,
302
- num_train_epochs=1,
303
- per_device_train_batch_size=1,
304
- gradient_accumulation_steps=2,
305
- num_generations=2,
306
- max_prompt_length=512,
307
- max_completion_length=64,
308
- logging_steps=2,
309
- report_to="none"
 
 
310
  ),
311
  train_dataset=dataset,
312
- tokenizer=tokenizer
313
  )
314
 
315
  # =========================
@@ -319,10 +357,36 @@ trainer = GRPOTrainer(
319
  if __name__ == "__main__":
320
  ensure_env_ready()
321
 
322
- print("🚀 Starting training...")
323
  trainer.train()
324
 
325
- model.save_pretrained("outputs/final")
326
- tokenizer.save_pretrained("outputs/final")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- print("Done")
 
11
  from unsloth import FastLanguageModel, PatchFastRL
12
  from trl import GRPOTrainer, GRPOConfig
13
 
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
16
  # =========================
 
18
  # =========================
19
 
20
  ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
21
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
22
+ HF_REPO = os.getenv("HF_REPO", "") # e.g. "yourname/metaguard-llama3.1-8b-grpo"
23
 
24
  ALLOWED_ACTIONS = [
25
  "query_regulations",
26
  "analyze_image",
27
  "check_advertiser_history",
28
+ "request_landing_page",
29
+ "request_id_verification",
30
  "submit_audit",
31
  "approve",
32
+ "reject",
33
  ]
34
 
35
  # =========================
 
161
  {"action_type": "submit_audit", "reasoning": "audit log"},
162
  ],
163
  },
164
+
165
+ # Targeting task — fresh state, expected: query_regulations
166
+ {
167
+ "task_id": "task_4_targeting",
168
+ "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
169
+ "actions_already_taken": [],
170
+ "setup_actions": [],
171
+ },
172
+ # Targeting task — mid state, expected: request_id_verification (age check)
173
+ {
174
+ "task_id": "task_4_targeting",
175
+ "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
176
+ "actions_already_taken": ["query_regulations"],
177
+ "setup_actions": [
178
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
179
+ ],
180
+ },
181
+ # Targeting task — audit ready
182
+ {
183
+ "task_id": "task_4_targeting",
184
+ "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
185
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
186
+ "setup_actions": [
187
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
188
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
189
+ {"action_type": "request_id_verification", "reasoning": "age check"},
190
+ ],
191
+ },
192
  ]
193
 
194
  PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
 
197
  - query_regulations
198
  - analyze_image
199
  - check_advertiser_history
200
+ - request_landing_page
201
+ - request_id_verification
202
  - submit_audit
203
  - approve
204
  - reject
 
234
  "task_id": s["task_id"],
235
  "setup_actions": s["setup_actions"],
236
  })
237
+ return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
238
 
239
  # =========================
240
  # REWARD FUNCTION (FIXED)
 
300
  # MODEL
301
  # =========================
302
 
303
+ USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_mem < 40 * 1024**3
304
+
305
  model, tokenizer = FastLanguageModel.from_pretrained(
306
  model_name="unsloth/Llama-3.1-8B-Instruct",
307
+ load_in_4bit=USE_4BIT,
308
+ max_seq_length=2048,
309
+ dtype=None, # auto-detect bf16 on A100
310
  )
311
 
312
  model = FastLanguageModel.get_peft_model(
313
  model,
314
+ r=32,
315
  target_modules=[
316
  "q_proj", "k_proj", "v_proj", "o_proj",
317
  "gate_proj", "up_proj", "down_proj",
318
  ],
319
+ lora_alpha=64,
320
  lora_dropout=0,
321
  bias="none",
322
  use_gradient_checkpointing="unsloth",
 
334
  reward_funcs=[reward_environment],
335
  args=GRPOConfig(
336
  output_dir="outputs",
337
+ learning_rate=2e-5,
338
+ num_train_epochs=3,
339
+ per_device_train_batch_size=2,
340
+ gradient_accumulation_steps=4,
341
+ num_generations=4,
342
+ max_prompt_length=768,
343
+ max_completion_length=128,
344
+ logging_steps=5,
345
+ warmup_ratio=0.1,
346
+ bf16=True,
347
+ report_to="none",
348
  ),
349
  train_dataset=dataset,
350
+ tokenizer=tokenizer,
351
  )
352
 
353
  # =========================
 
357
  if __name__ == "__main__":
358
  ensure_env_ready()
359
 
360
+ print("Starting GRPO training...")
361
  trainer.train()
362
 
363
+ model.save_pretrained("outputs/lora_adapter")
364
+ tokenizer.save_pretrained("outputs/lora_adapter")
365
+ print("LoRA adapter saved to outputs/lora_adapter")
366
+
367
+ print("Merging adapter into base model (bf16)...")
368
+ merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
369
+ model_name="outputs/lora_adapter",
370
+ load_in_4bit=False,
371
+ max_seq_length=2048,
372
+ )
373
+ merged_model.save_pretrained_merged(
374
+ "outputs/merged",
375
+ merged_tokenizer,
376
+ save_method="merged_16bit",
377
+ )
378
+ print("Merged model saved to outputs/merged")
379
+
380
+ if HF_REPO:
381
+ print(f"Pushing merged model to {HF_REPO}...")
382
+ merged_model.push_to_hub_merged(
383
+ HF_REPO,
384
+ merged_tokenizer,
385
+ save_method="merged_16bit",
386
+ token=HF_TOKEN,
387
+ )
388
+ print(f"Model live at https://huggingface.co/{HF_REPO}")
389
+ else:
390
+ print("Set HF_REPO env var to auto-push to Hub (skipped).")
391
 
392
+ print("Done.")