round 2 improvement updated GRPO
Browse files- grpo_train.py +132 -37
grpo_train.py
CHANGED
|
@@ -1,56 +1,151 @@
|
|
|
|
|
| 1 |
import torch
|
|
|
|
|
|
|
| 2 |
from unsloth import FastLanguageModel, PatchFastRL
|
| 3 |
from trl import GRPOTrainer, GRPOConfig
|
| 4 |
-
from src.environment import AdPolicyEnvironment
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
rewards = []
|
| 16 |
-
for completion in completions:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
return rewards
|
| 23 |
|
| 24 |
def reward_json_format(prompts, completions, **kwargs):
|
|
|
|
| 25 |
rewards = []
|
| 26 |
for completion in completions:
|
| 27 |
try:
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
return rewards
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
trainer = GRPOTrainer(
|
| 48 |
-
model
|
| 49 |
-
reward_funcs
|
| 50 |
-
args
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
import torch
|
| 3 |
+
import requests
|
| 4 |
+
from datasets import Dataset
|
| 5 |
from unsloth import FastLanguageModel, PatchFastRL
|
| 6 |
from trl import GRPOTrainer, GRPOConfig
|
|
|
|
| 7 |
|
| 8 |
+
# MUST be called before trainer instantiation
|
| 9 |
+
PatchFastRL("GRPO", FastLanguageModel)
|
| 10 |
+
|
| 11 |
+
ENV_URL = "http://localhost:8000"
|
| 12 |
+
TASKS = ["task_1_healthcare", "task_2_financial",
|
| 13 |
+
"task_3_multimodal", "task_4_targeting"]
|
| 14 |
+
|
| 15 |
+
SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
|
| 16 |
+
Always respond with ONLY valid JSON, no markdown.
|
| 17 |
+
|
| 18 |
+
REQUIRED PHASE ORDER:
|
| 19 |
+
1. query_regulations β always first
|
| 20 |
+
2. analyze_image β required for multimodal tasks
|
| 21 |
+
3. submit_audit β always before final decision
|
| 22 |
+
4. approve or reject β only after audit
|
| 23 |
|
| 24 |
+
Format: {"action_type": "<action>", "reasoning": "<reason>"}"""
|
| 25 |
+
|
| 26 |
+
# ββ DATASET βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
|
| 28 |
+
def build_dataset():
|
| 29 |
+
rows = []
|
| 30 |
+
for task_id in TASKS:
|
| 31 |
+
res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
|
| 32 |
+
obs = res.json()
|
| 33 |
+
prompt = (
|
| 34 |
+
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
|
| 35 |
+
f"{SYSTEM_PROMPT}<|eot_id|>"
|
| 36 |
+
f"<|start_header_id|>user<|end_header_id|>\n"
|
| 37 |
+
f"Task: {task_id}\n"
|
| 38 |
+
f"Ad: {obs.get('headline','N/A')} β {obs.get('body_text','N/A')}\n"
|
| 39 |
+
f"Trust Score: {obs.get('advertiser_trust_score','N/A')}\n"
|
| 40 |
+
f"Status: {obs.get('status_message','')}\n"
|
| 41 |
+
f"What is your next action?"
|
| 42 |
+
f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
|
| 43 |
+
)
|
| 44 |
+
rows.append({"prompt": prompt, "task_id": task_id})
|
| 45 |
+
# 25x repetition = 100 rows, enough for 1 epoch
|
| 46 |
+
return Dataset.from_list(rows * 25)
|
| 47 |
+
|
| 48 |
+
# ββ REWARD FUNCTION (actually calls the environment) ββββββββββββββββββββββββββ
|
| 49 |
+
|
| 50 |
+
def reward_environment(prompts, completions, task_ids, **kwargs):
|
| 51 |
+
"""
|
| 52 |
+
This is the real reward β model outputs an action,
|
| 53 |
+
we send it to the environment, environment returns the reward.
|
| 54 |
+
"""
|
| 55 |
rewards = []
|
| 56 |
+
for completion, task_id in zip(completions, task_ids):
|
| 57 |
+
try:
|
| 58 |
+
# Parse model output
|
| 59 |
+
content = completion.strip()
|
| 60 |
+
if content.startswith("```"):
|
| 61 |
+
content = content.split("```")[1]
|
| 62 |
+
if content.startswith("json"):
|
| 63 |
+
content = content[4:]
|
| 64 |
+
action = json.loads(content.strip())
|
| 65 |
+
action_type = action.get("action_type", "query_regulations")
|
| 66 |
+
except Exception:
|
| 67 |
+
# Malformed JSON = penalty
|
| 68 |
+
rewards.append(-0.5)
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
# Fresh episode for each reward calculation
|
| 73 |
+
requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
|
| 74 |
+
|
| 75 |
+
# Run a minimal sequence: if model says query_regulations,
|
| 76 |
+
# run that then check what reward it generates
|
| 77 |
+
step_res = requests.post(
|
| 78 |
+
f"{ENV_URL}/step",
|
| 79 |
+
json={"action": {"action_type": action_type,
|
| 80 |
+
"reasoning": action.get("reasoning", "")}},
|
| 81 |
+
timeout=5
|
| 82 |
+
)
|
| 83 |
+
data = step_res.json()
|
| 84 |
+
rewards.append(float(data.get("reward", -0.1)))
|
| 85 |
+
except Exception:
|
| 86 |
+
rewards.append(-0.1)
|
| 87 |
+
|
| 88 |
return rewards
|
| 89 |
|
| 90 |
def reward_json_format(prompts, completions, **kwargs):
|
| 91 |
+
"""Bonus reward for valid JSON output."""
|
| 92 |
rewards = []
|
| 93 |
for completion in completions:
|
| 94 |
try:
|
| 95 |
+
content = completion.strip()
|
| 96 |
+
if content.startswith("```"):
|
| 97 |
+
content = content.split("```")[1]
|
| 98 |
+
if content.startswith("json"):
|
| 99 |
+
content = content[4:]
|
| 100 |
+
json.loads(content.strip())
|
| 101 |
+
rewards.append(0.5)
|
| 102 |
+
except Exception:
|
| 103 |
+
rewards.append(-0.5)
|
| 104 |
return rewards
|
| 105 |
|
| 106 |
+
# ββ MODEL SETUP βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
|
| 108 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 109 |
+
model_name="unsloth/Llama-3.1-8B-Instruct",
|
| 110 |
+
max_seq_length=1024,
|
| 111 |
+
load_in_4bit=True,
|
| 112 |
+
)
|
| 113 |
+
model = FastLanguageModel.get_peft_model(
|
| 114 |
+
model,
|
| 115 |
+
r=16,
|
| 116 |
+
target_modules=["q_proj", "v_proj"],
|
| 117 |
+
lora_alpha=16,
|
| 118 |
+
lora_dropout=0.0,
|
| 119 |
+
use_gradient_checkpointing="unsloth",
|
| 120 |
)
|
| 121 |
|
| 122 |
+
# ββ TRAINER βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 123 |
+
|
| 124 |
+
dataset = build_dataset()
|
| 125 |
+
|
| 126 |
trainer = GRPOTrainer(
|
| 127 |
+
model=model,
|
| 128 |
+
reward_funcs=[reward_environment, reward_json_format],
|
| 129 |
+
args=GRPOConfig(
|
| 130 |
+
output_dir="outputs/meta-ad-agent",
|
| 131 |
+
learning_rate=5e-6,
|
| 132 |
+
num_train_epochs=1,
|
| 133 |
+
per_device_train_batch_size=2,
|
| 134 |
+
gradient_accumulation_steps=4,
|
| 135 |
+
max_prompt_length=512,
|
| 136 |
+
max_completion_length=128,
|
| 137 |
+
num_generations=4, # lower = faster, enough for demo
|
| 138 |
+
logging_steps=5,
|
| 139 |
+
save_steps=50,
|
| 140 |
+
report_to="none",
|
| 141 |
+
),
|
| 142 |
+
train_dataset=dataset,
|
| 143 |
+
tokenizer=tokenizer,
|
| 144 |
)
|
| 145 |
|
| 146 |
+
if __name__ == "__main__":
|
| 147 |
+
print("Starting GRPO training β environment must be running on :8000")
|
| 148 |
+
trainer.train()
|
| 149 |
+
model.save_pretrained("outputs/meta-ad-agent-final")
|
| 150 |
+
tokenizer.save_pretrained("outputs/meta-ad-agent-final")
|
| 151 |
+
print("Done. Model saved to outputs/meta-ad-agent-final")
|