# Scraped from a Hugging Face Spaces file view (Space status: Runtime error; file size: 5,811 bytes).
import json
import torch
import requests
from datasets import Dataset
from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOTrainer, GRPOConfig
# MUST be called before trainer instantiation: this runs at import time, ahead
# of the GRPOTrainer construction further down in this file.
PatchFastRL("GRPO", FastLanguageModel)
# Base URL of the environment server whose /reset and /step endpoints this
# script POSTs to.
ENV_URL = "http://localhost:8000"

# Task identifiers passed as "task_id" when resetting the environment.
TASKS = [
    "task_1_healthcare",
    "task_2_financial",
    "task_3_multimodal",
    "task_4_targeting",
]

# System prompt placed at the start of every training prompt. The separators
# are real em-dashes; the previous text carried mis-decoded bytes in their place.
SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
Always respond with ONLY valid JSON, no markdown.
REQUIRED PHASE ORDER:
1. query_regulations — always first
2. analyze_image — required for multimodal tasks
3. submit_audit — always before final decision
4. approve or reject — only after audit
Format: {"action_type": "<action>", "reasoning": "<reason>"}"""
# ── DATASET ───────────────────────────────────────────────────────────────────
def build_dataset():
    """Build the training dataset by resetting the environment once per task.

    For every entry in ``TASKS`` the environment's ``/reset`` endpoint is
    called and the returned observation is rendered into a hand-built chat
    prompt (system / user / assistant header tokens).

    Returns:
        A ``Dataset`` with ``"prompt"`` and ``"task_id"`` columns, containing
        each task's row repeated 25 times.

    Raises:
        requests.RequestException: If the environment is unreachable, does not
            answer within the timeout, or replies with a non-2xx status.
    """
    rows = []
    for task_id in TASKS:
        # A bounded timeout keeps a dead environment from hanging start-up.
        res = requests.post(
            f"{ENV_URL}/reset",
            json={"task_id": task_id},
            timeout=10,
        )
        # Fail loudly instead of building prompts out of an error payload.
        res.raise_for_status()
        # Assumes the reset response is a flat JSON object carrying the keys
        # read below - TODO confirm against the environment's schema.
        obs = res.json()
        prompt = (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"{SYSTEM_PROMPT}<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n"
            f"Task: {task_id}\n"
            f"Ad: {obs.get('headline', 'N/A')} — {obs.get('body_text', 'N/A')}\n"
            f"Trust Score: {obs.get('advertiser_trust_score', 'N/A')}\n"
            f"Status: {obs.get('status_message', '')}\n"
            "What is your next action?"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )
        rows.append({"prompt": prompt, "task_id": task_id})
    # 25x repetition of the per-task rows (100 rows for the 4 tasks above),
    # enough for 1 epoch.
    return Dataset.from_list(rows * 25)
# ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
def _parse_action(completion):
    """Return the JSON object encoded in ``completion``, or ``None`` if there is none."""
    try:
        content = completion.strip()
        if content.startswith("```"):
            # Keep only the body of the first fenced block.
            content = content.split("```")[1]
            if content.startswith("json"):
                content = content[4:]
        action = json.loads(content.strip())
    except (AttributeError, TypeError, ValueError):
        # Non-string completion or invalid JSON.
        return None
    # Valid JSON that is not an object (list, number, ...) carries no action.
    return action if isinstance(action, dict) else None


def reward_environment(prompts, completions, task_id, **kwargs):
    """Score each completion by replaying its action against the environment.

    The model output is parsed as a JSON action, a fresh episode is started
    for the matching task, the action is sent as a single step, and the
    reward reported by the environment is used.

    Args:
        prompts: Prompts the completions were generated from (unused).
        completions: Raw completion strings produced by the model.
        task_id: Task identifiers from the dataset, aligned with ``completions``.
        **kwargs: Extra dataset columns passed by the trainer (unused).

    Returns:
        One float per completion: ``-0.5`` when the output is not a JSON
        object, ``-0.1`` when the environment call fails or its response has
        no usable reward, otherwise the environment's ``"reward"`` value.
    """
    rewards = []
    # task_id comes from the dataset column and is aligned with completions.
    for completion, t_id in zip(completions, task_id):
        action = _parse_action(completion)
        if action is None:
            # Malformed JSON = penalty
            rewards.append(-0.5)
            continue
        # NOTE(review): an object without "action_type" is scored as if the
        # model had chosen "query_regulations" - confirm this is intended.
        action_type = action.get("action_type", "query_regulations")
        try:
            # Fresh episode for each reward calculation. The timeout and the
            # status check keep a hung or failed reset from stalling training
            # or from scoring the step against a stale episode.
            reset_res = requests.post(
                f"{ENV_URL}/reset",
                json={"task_id": t_id},
                timeout=5,
            )
            reset_res.raise_for_status()
            # Send the single action chosen by the model and read its reward.
            step_res = requests.post(
                f"{ENV_URL}/step",
                json={"action": {"action_type": action_type,
                                 "reasoning": action.get("reasoning", "")}},
                timeout=5,
            )
            step_res.raise_for_status()
            data = step_res.json()
            rewards.append(float(data.get("reward", -0.1)))
        except (requests.RequestException, AttributeError, TypeError, ValueError):
            # Network/HTTP failure, non-JSON body, or a non-numeric reward.
            rewards.append(-0.1)
    return rewards
def reward_json_format(prompts, completions, **kwargs):
    """Reward completions that parse as JSON and penalise those that do not.

    A leading Markdown code fence (optionally tagged ``json``) is stripped
    before parsing. ``prompts`` and ``**kwargs`` are accepted but unused.

    Returns:
        A list with ``0.5`` for every parseable completion and ``-0.5`` for
        every other one, in the order of ``completions``.
    """
    fence = "```"

    def _score(raw):
        # Any failure while cleaning or parsing counts as invalid output.
        try:
            text = raw.strip()
            if text.startswith(fence):
                text = text.split(fence)[1]
                if text.startswith("json"):
                    text = text[4:]
            json.loads(text.strip())
        except Exception:
            return -0.5
        return 0.5

    return [_score(raw) for raw in completions]
# ── MODEL SETUP ───────────────────────────────────────────────────────────────
# Load the base checkpoint in 4-bit together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_4bit=True,
    max_seq_length=1024,
    model_name="unsloth/Llama-3.1-8B-Instruct",
)

# LoRA settings: rank-16 adapters on the q_proj and v_proj modules, no dropout.
_lora_settings = {
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.0,
    "target_modules": ["q_proj", "v_proj"],
    "use_gradient_checkpointing": "unsloth",
}
model = FastLanguageModel.get_peft_model(model, **_lora_settings)
# ── TRAINER ───────────────────────────────────────────────────────────────────
# Building the dataset calls the environment's /reset endpoint, so the
# environment server must already be running when this module is loaded.
dataset = build_dataset()

# Both reward functions are evaluated for every generated completion.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_environment, reward_json_format],
    args=GRPOConfig(
        output_dir="outputs/meta-ad-agent",
        learning_rate=5e-6,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_prompt_length=512,
        max_completion_length=128,
        # NOTE(review): some TRL versions require the effective batch size to
        # be divisible by num_generations - confirm for the installed version.
        num_generations=4,  # lower = faster, enough for demo
        logging_steps=5,
        save_steps=50,
        report_to="none",
    ),
    train_dataset=dataset,
    # GRPOTrainer takes the tokenizer through `processing_class`; it has no
    # `tokenizer` keyword argument.
    processing_class=tokenizer,
)
if __name__ == "__main__":
    # Directory that receives the final adapter weights and the tokenizer.
    final_dir = "outputs/meta-ad-agent-final"

    print("Starting GRPO training — environment must be running on :8000")
    trainer.train()
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Done. Model saved to {final_dir}")