File size: 5,811 Bytes
193a9d2
 
aea0b8c
193a9d2
 
 
 
aea0b8c
193a9d2
 
aea0b8c
 
 
a5a9c5a
aea0b8c
 
193a9d2
a5a9c5a
aea0b8c
 
 
 
a5a9c5a
aea0b8c
193a9d2
aea0b8c
193a9d2
 
 
aea0b8c
 
 
 
 
 
 
 
 
 
 
 
 
193a9d2
aea0b8c
 
 
193a9d2
aea0b8c
193a9d2
aea0b8c
 
 
 
 
193a9d2
aea0b8c
 
 
 
 
 
 
 
 
 
 
 
 
 
193a9d2
 
 
aea0b8c
 
 
 
 
 
 
 
 
 
193a9d2
aea0b8c
 
 
 
193a9d2
aea0b8c
193a9d2
aea0b8c
 
 
 
 
 
 
 
 
 
 
 
193a9d2
aea0b8c
193a9d2
 
aea0b8c
193a9d2
 
a5a9c5a
aea0b8c
 
193a9d2
 
 
aea0b8c
 
 
 
193a9d2
 
 
aea0b8c
193a9d2
 
 
 
 
aea0b8c
193a9d2
aea0b8c
 
 
 
 
 
a5a9c5a
aea0b8c
70c7c72
aea0b8c
193a9d2
 
 
 
 
 
 
aea0b8c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import json
import torch
import requests
from datasets import Dataset
from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOTrainer, GRPOConfig

# Apply Unsloth's GRPO patch for FastLanguageModel (presumably patches TRL's
# GRPO code path for Unsloth models — confirm against the unsloth docs).
# MUST be called before the GRPOTrainer is instantiated further down this file;
# calling it afterwards has no effect on the already-built trainer.
PatchFastRL("GRPO", FastLanguageModel)

# Base URL of the ad-policy environment server (exposes /reset and /step).
ENV_URL = "http://localhost:8000"
# Task ids sent to the environment's /reset endpoint; one dataset row per task.
TASKS = ["task_1_healthcare", "task_2_financial",
         "task_3_multimodal", "task_4_targeting"]

# System prompt shown to the policy model. It is embedded verbatim in every
# training prompt, so it must be clean text: the dashes below are real em dashes
# (the previous version contained the mis-decoded sequence "β€”").
SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
Always respond with ONLY valid JSON, no markdown.

REQUIRED PHASE ORDER:
1. query_regulations  — always first
2. analyze_image      — required for multimodal tasks
3. submit_audit       — always before final decision
4. approve or reject  — only after audit

Format: {"action_type": "<action>", "reasoning": "<reason>"}"""

# ── DATASET ───────────────────────────────────────────────────────────────────

def build_dataset(repeats=25, timeout=10):
    """Build the GRPO prompt dataset from one environment reset per task.

    For every id in ``TASKS`` the environment's ``/reset`` endpoint is called.
    The returned observation is rendered into a Llama-3 chat-formatted prompt:
    the system prompt, then a user turn with the task, the ad text, the trust
    score and the status message. The prompt ends with the assistant header, so
    the model's completion is its next action.

    Args:
        repeats: How many times the per-task rows are repeated. With the four
            tasks in ``TASKS`` the default of 25 yields 100 rows.
        timeout: Seconds to wait for each ``/reset`` response.

    Returns:
        A ``datasets.Dataset`` with ``prompt`` and ``task_id`` columns.

    Raises:
        requests.HTTPError: if ``/reset`` answers with an error status.
        requests.RequestException: if the environment is unreachable or times out.
    """
    rows = []
    for task_id in TASKS:
        res = requests.post(
            f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=timeout
        )
        # Fail fast: an error payload would otherwise be rendered as "N/A" fields
        # and silently turn the whole dataset into meaningless prompts.
        res.raise_for_status()
        obs = res.json()
        # Llama 3 chat format: each "<|end_header_id|>" is followed by a blank
        # line ("\n\n"), and the prompt ends with an open assistant header.
        prompt = (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{SYSTEM_PROMPT}<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"Task: {task_id}\n"
            f"Ad: {obs.get('headline', 'N/A')} — {obs.get('body_text', 'N/A')}\n"
            f"Trust Score: {obs.get('advertiser_trust_score', 'N/A')}\n"
            f"Status: {obs.get('status_message', '')}\n"
            "What is your next action?"
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        rows.append({"prompt": prompt, "task_id": task_id})
    return Dataset.from_list(rows * repeats)

# ── REWARD FUNCTION (actually calls the environment) ──────────────────────────

def reward_environment(prompts, completions, task_id, **kwargs):
    """Score each completion by executing its action in the live environment.

    Each completion is parsed as a JSON action; a surrounding Markdown code
    fence (optionally tagged ``json``) is stripped first. The completion's task
    is reset and the single action is sent to ``/step``. The ``reward`` in the
    step response becomes the completion's score.

    Args:
        prompts: Batch prompts (unused; part of the TRL reward-function signature).
        completions: Model completions, one string per sample.
        task_id: Task id for each completion (the dataset's ``task_id`` column).
        **kwargs: Any other dataset columns forwarded by the trainer (ignored).

    Returns:
        One float per completion:

        * ``-0.5`` if the completion is not a JSON object with a non-empty
          string ``action_type``;
        * the step response's ``reward`` (``-0.1`` if the response has none);
        * ``-0.1`` if either request fails or the response cannot be read.
    """
    rewards = []
    for completion, t_id in zip(completions, task_id):
        try:
            content = completion.strip()
            if content.startswith("```"):
                content = content.split("```")[1]
                if content.startswith("json"):
                    content = content[4:]
            action = json.loads(content.strip())
            # .get raises AttributeError when the JSON is not an object.
            action_type = action.get("action_type")
            reasoning = action.get("reasoning", "")
        except (AttributeError, ValueError):
            # Not a string, not valid JSON, or JSON that is not an object.
            rewards.append(-0.5)
            continue

        if not isinstance(action_type, str) or not action_type:
            # Never default a missing action: a bare "{}" would otherwise earn
            # the query_regulations reward without the model choosing anything.
            rewards.append(-0.5)
            continue

        try:
            # Fresh episode per completion, then exactly one step with its action.
            reset_res = requests.post(
                f"{ENV_URL}/reset", json={"task_id": t_id}, timeout=5
            )
            reset_res.raise_for_status()
            step_res = requests.post(
                f"{ENV_URL}/step",
                json={"action": {"action_type": action_type,
                                 "reasoning": reasoning}},
                timeout=5,
            )
            step_res.raise_for_status()
            data = step_res.json()
            rewards.append(float(data.get("reward", -0.1)))
        except (requests.RequestException, ValueError, TypeError, AttributeError):
            # Network/HTTP failure, non-JSON body, non-object body, or a
            # non-numeric reward: small penalty rather than aborting training.
            rewards.append(-0.1)

    return rewards

def reward_json_format(prompts, completions, **kwargs):
    """Format reward: +0.5 if a completion parses as a JSON object, else -0.5.

    A surrounding Markdown code fence (optionally tagged ``json``) is stripped
    before parsing. Only a JSON object qualifies, because the required action
    format is an object. Valid JSON scalars or arrays such as ``42`` or ``[]``
    do not earn the bonus.

    Args:
        prompts: Batch prompts (unused; part of the TRL reward-function signature).
        completions: Model completions, one string per sample.
        **kwargs: Dataset columns forwarded by the trainer (ignored).

    Returns:
        A list with one float (``0.5`` or ``-0.5``) per completion.
    """
    rewards = []
    for completion in completions:
        try:
            content = completion.strip()
            if content.startswith("```"):
                content = content.split("```")[1]
                if content.startswith("json"):
                    content = content[4:]
            parsed = json.loads(content.strip())
        except (AttributeError, ValueError):
            # Non-string completion or invalid JSON.
            parsed = None
        rewards.append(0.5 if isinstance(parsed, dict) else -0.5)
    return rewards

# ── MODEL SETUP ───────────────────────────────────────────────────────────────

# Base model: Llama-3.1-8B-Instruct from the Unsloth hub, loaded in 4-bit with
# a 1024-token context (enough for a <=512-token prompt + 128-token completion).
model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_4bit=True,
    max_seq_length=1024,
    model_name="unsloth/Llama-3.1-8B-Instruct",
)
# Attach LoRA adapters (rank 16, alpha 16, no dropout) to the attention
# q/v projections, with Unsloth's gradient-checkpointing mode.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    use_gradient_checkpointing="unsloth",
)

# ── TRAINER ───────────────────────────────────────────────────────────────────

# Built at import time: requires the environment server to be reachable.
dataset = build_dataset()

# Both reward functions score every completion; TRL combines them into one
# reward per completion (equal weights unless GRPOConfig.reward_weights is set).
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_environment, reward_json_format],
    args=GRPOConfig(
        output_dir="outputs/meta-ad-agent",
        learning_rate=5e-6,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_prompt_length=512,
        max_completion_length=128,
        num_generations=4,          # lower = faster, enough for demo
        logging_steps=5,
        save_steps=50,
        report_to="none",
    ),
    train_dataset=dataset,
    # GRPOTrainer receives the tokenizer via `processing_class`; it has no
    # `tokenizer` parameter.
    processing_class=tokenizer,
)

if __name__ == "__main__":
    # The reward function calls the environment for every completion, so the
    # server must already be up before training starts.
    print("Starting GRPO training — environment must be running on :8000")
    trainer.train()
    # Persist the trained (PEFT-wrapped) model and its tokenizer side by side.
    final_dir = "outputs/meta-ad-agent-final"
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"Done. Model saved to {final_dir}")