3v324v23 committed on
Commit
7c3bc96
·
1 Parent(s): c6123cd

round 2 improvement updated GRPO

Browse files
Files changed (1) hide show
  1. grpo_train.py +132 -37
grpo_train.py CHANGED
@@ -1,56 +1,151 @@
 
1
  import torch
 
 
2
  from unsloth import FastLanguageModel, PatchFastRL
3
  from trl import GRPOTrainer, GRPOConfig
4
- from src.environment import AdPolicyEnvironment
5
 
6
- # 1. Load Model with Unsloth
7
- model, tokenizer = FastLanguageModel.from_pretrained(
8
- model_name = "unsloth/Llama-3.1-8B-Instruct",
9
- max_seq_length = 1024,
10
- load_in_4bit = True,
11
- )
 
 
 
 
 
 
 
 
 
12
 
13
- # 2. Define Reward Functions
14
- def reward_compliance(prompts, completions, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  rewards = []
16
- for completion in completions:
17
- # Check if the model called the necessary tools in order
18
- if "query_regulations" in completion and "submit_audit" in completion:
19
- rewards.append(2.0)
20
- else:
21
- rewards.append(0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  return rewards
23
 
24
  def reward_json_format(prompts, completions, **kwargs):
 
25
  rewards = []
26
  for completion in completions:
27
  try:
28
- import json
29
- json.loads(completion)
30
- rewards.append(1.0)
31
- except:
32
- rewards.append(0.0)
 
 
 
 
33
  return rewards
34
 
35
- # 3. Configure Trainer
36
- training_args = GRPOConfig(
37
- output_dir = "outputs/meta-ad-agent",
38
- learning_rate = 5e-6,
39
- num_train_epochs = 1,
40
- per_device_train_batch_size = 4,
41
- gradient_accumulation_steps = 4,
42
- max_prompt_length = 512,
43
- max_completion_length = 512,
44
- num_generations = 8, # Number of variations to compare
 
 
 
 
45
  )
46
 
 
 
 
 
47
  trainer = GRPOTrainer(
48
- model = model,
49
- reward_funcs = [reward_compliance, reward_json_format],
50
- args = training_args,
51
- train_dataset = [], # We will stream data from your AdGenerator here
52
- tokenizer = tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
- # 4. Start Training
56
- # trainer.train()
 
 
 
 
 
1
+ import json
2
  import torch
3
+ import requests
4
+ from datasets import Dataset
5
  from unsloth import FastLanguageModel, PatchFastRL
6
  from trl import GRPOTrainer, GRPOConfig
 
7
 
8
# Apply the GRPO patch to FastLanguageModel. This call has to happen
# before the trainer is instantiated further down.
PatchFastRL("GRPO", FastLanguageModel)
10
+
11
# Base URL of the ad-policy environment server (exposes /reset and /step).
ENV_URL = "http://localhost:8000"

# Task identifiers sent to the environment's /reset endpoint.
TASKS = [
    "task_1_healthcare",
    "task_2_financial",
    "task_3_multimodal",
    "task_4_targeting",
]

# System turn placed at the start of every training prompt. The dashes are
# real em-dashes; mis-encoded bytes here would end up in the model's input.
SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
Always respond with ONLY valid JSON, no markdown.

REQUIRED PHASE ORDER:
1. query_regulations — always first
2. analyze_image — required for multimodal tasks
3. submit_audit — always before final decision
4. approve or reject — only after audit

Format: {"action_type": "<action>", "reasoning": "<reason>"}"""
25
+
26
+ # ── DATASET ───────────────────────────────────────────────────────────────────
27
+
28
def build_dataset(repeats=25, timeout=10):
    """Build the GRPO training prompts from fresh environment resets.

    For every entry in ``TASKS`` the environment's ``/reset`` endpoint is
    called and the returned observation is rendered into a Llama-3 style chat
    prompt: a system turn, a user turn describing the ad, and an open
    assistant header for the model to complete.

    Args:
        repeats: How many times the per-task rows are repeated. The default
            of 25 gives 100 rows for the four tasks, enough for one epoch.
        timeout: Seconds to wait for each ``/reset`` response.

    Returns:
        A ``datasets.Dataset`` with ``prompt`` and ``task_id`` columns.

    Raises:
        requests.RequestException: If the environment cannot be reached,
            times out, or answers with an HTTP error status.
    """
    rows = []
    for task_id in TASKS:
        res = requests.post(
            f"{ENV_URL}/reset",
            json={"task_id": task_id},
            timeout=timeout,  # without this a dead server blocks forever
        )
        # Fail loudly: an error payload would otherwise be rendered into a
        # prompt made entirely of "N/A" placeholders.
        res.raise_for_status()
        obs = res.json()
        prompt = (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"{SYSTEM_PROMPT}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n"
            f"Task: {task_id}\n"
            f"Ad: {obs.get('headline','N/A')} — {obs.get('body_text','N/A')}\n"
            f"Trust Score: {obs.get('advertiser_trust_score','N/A')}\n"
            f"Status: {obs.get('status_message','')}\n"
            f"What is your next action?"
            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        # The extra "task_id" column travels with each row so the reward
        # function can tell which task a completion belongs to.
        rows.append({"prompt": prompt, "task_id": task_id})
    return Dataset.from_list(rows * repeats)
47
+
48
+ # ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
49
+
50
def _parse_action(completion):
    """Return the action dict encoded in *completion*, or None if malformed."""
    # Conversational datasets deliver a list of message dicts, not a string.
    if isinstance(completion, list) and completion and isinstance(completion[0], dict):
        completion = completion[0].get("content")
    if not isinstance(completion, str):
        return None
    content = completion.strip()
    if content.startswith("```"):
        # Keep only the body of the first fenced block, minus a "json" tag.
        content = content.split("```")[1]
        if content.startswith("json"):
            content = content[4:]
    try:
        action = json.loads(content.strip())
    except (ValueError, RecursionError):
        return None
    # Only a JSON object can carry "action_type" / "reasoning".
    return action if isinstance(action, dict) else None


def reward_environment(prompts, completions, task_ids=None, **kwargs):
    """Score each completion by executing its action in the environment.

    The completion is parsed as a JSON action, a fresh episode is started
    for its task, the action is sent to ``/step`` and the ``reward`` field
    of the response becomes the score.

    Args:
        prompts: Prompts that produced the completions (unused).
        completions: Model outputs, one per sample.
        task_ids: Task identifier for every completion. When omitted, the
            ``task_id`` keyword is used instead. The dataset column is named
            ``task_id``, and that is the name the trainer forwards.
        **kwargs: Any further dataset columns passed through by the trainer.

    Returns:
        One float per completion: ``-0.5`` for output that is not a JSON
        object, ``-0.1`` when the environment call fails or the response has
        no ``reward``, otherwise the environment's reward.

    Raises:
        ValueError: If neither ``task_ids`` nor ``task_id`` is provided.
    """
    if task_ids is None:
        task_ids = kwargs.get("task_id")
    if task_ids is None:
        raise ValueError(
            "reward_environment needs the task of each completion: pass "
            "`task_ids` or include a `task_id` column in the dataset"
        )

    rewards = []
    for completion, task_id in zip(completions, task_ids):
        action = _parse_action(completion)
        if action is None:
            # Malformed JSON = penalty
            rewards.append(-0.5)
            continue

        try:
            # Fresh episode for each reward calculation. The timeout matches
            # the /step call so a stalled server cannot hang training.
            reset_res = requests.post(
                f"{ENV_URL}/reset",
                json={"task_id": task_id},
                timeout=5,
            )
            # A failed reset would leave the step running on a stale episode.
            reset_res.raise_for_status()

            step_res = requests.post(
                f"{ENV_URL}/step",
                json={"action": {
                    # A missing action_type falls back to the first phase.
                    "action_type": action.get("action_type", "query_regulations"),
                    "reasoning": action.get("reasoning", ""),
                }},
                timeout=5,
            )
            reward = float(step_res.json().get("reward", -0.1))
        except Exception:
            # Deliberately broad: any environment failure costs a small
            # penalty instead of aborting the whole training run.
            reward = -0.1
        rewards.append(reward)

    return rewards
89
 
90
def _is_valid_json(completion):
    """Return True if *completion* parses as JSON once a code fence is removed."""
    # Conversational datasets deliver a list of message dicts, not a string.
    if isinstance(completion, list) and completion and isinstance(completion[0], dict):
        completion = completion[0].get("content")
    if not isinstance(completion, str):
        return False
    content = completion.strip()
    if content.startswith("```"):
        # Keep only the body of the first fenced block, minus a "json" tag.
        content = content.split("```")[1]
        if content.startswith("json"):
            content = content[4:]
    try:
        json.loads(content.strip())
    except (ValueError, RecursionError):
        return False
    return True


def reward_json_format(prompts, completions, **kwargs):
    """Bonus reward for valid JSON output.

    Args:
        prompts: Prompts that produced the completions (unused).
        completions: Model outputs, either plain strings or lists of message
            dicts whose first entry holds the text under ``"content"``.
        **kwargs: Extra dataset columns passed through by the trainer (unused).

    Returns:
        One float per completion: ``0.5`` when the text (optionally wrapped
        in a Markdown code fence) parses as JSON, ``-0.5`` otherwise.
    """
    return [0.5 if _is_valid_json(completion) else -0.5 for completion in completions]
105
 
106
+ # ── MODEL SETUP ───────────────────────────────────────────────────────────────
107
+
108
# Load the base model in 4-bit together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    load_in_4bit=True,
    max_seq_length=1024,
)

# Wrap the model with LoRA adapters on the q_proj and v_proj modules.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    use_gradient_checkpointing="unsloth",
)
121
 
122
+ # ── TRAINER ───────────────────────────────────────────────────────────────────
123
+
124
# Built at import time; this sends one /reset request per task, so the
# environment server has to be up before this module is loaded.
dataset = build_dataset()

trainer = GRPOTrainer(
    model=model,
    # Both functions are called for every batch of completions.
    reward_funcs=[reward_environment, reward_json_format],
    args=GRPOConfig(
        output_dir="outputs/meta-ad-agent",
        learning_rate=5e-6,
        num_train_epochs=1,
        # NOTE(review): TRL checks the batch size against num_generations;
        # confirm 2 (x4 accumulation) is accepted by the installed version.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_prompt_length=512,
        max_completion_length=128,
        num_generations=4,  # lower = faster, enough for demo
        logging_steps=5,
        save_steps=50,
        report_to="none",
    ),
    train_dataset=dataset,
    # GRPOTrainer's documented parameter for the tokenizer is
    # `processing_class`; it has no `tokenizer` argument.
    processing_class=tokenizer,
)

if __name__ == "__main__":
    print("Starting GRPO training — environment must be running on :8000")
    trainer.train()
    # Persist the trained weights and the tokenizer side by side.
    model.save_pretrained("outputs/meta-ad-agent-final")
    tokenizer.save_pretrained("outputs/meta-ad-agent-final")
    print("Done. Model saved to outputs/meta-ad-agent-final")