File size: 14,937 Bytes
daa0358
 
 
 
7c3bc96
daa0358
7c3bc96
daa0358
 
7c3bc96
8a685c0
 
 
7c3bc96
 
88c89bd
 
 
 
 
 
 
 
 
 
daa0358
 
 
7c3bc96
daa0358
b9daf1b
 
7c3bc96
daa0358
 
 
 
b9daf1b
 
daa0358
 
b9daf1b
daa0358
8a685c0
daa0358
 
 
7c3bc96
daa0358
88c89bd
 
 
 
daa0358
 
 
 
 
 
 
88c89bd
 
 
daa0358
 
88c89bd
 
 
 
daa0358
 
88c89bd
 
 
daa0358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47fa380
daa0358
 
47fa380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0358
 
47fa380
daa0358
 
47fa380
 
daa0358
47fa380
 
 
 
 
 
 
 
 
 
daa0358
 
47fa380
daa0358
 
47fa380
 
daa0358
47fa380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9daf1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0358
7c3bc96
47fa380
daa0358
47fa380
 
 
 
b9daf1b
 
47fa380
 
 
daa0358
47fa380
 
 
 
 
daa0358
47fa380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0358
 
 
47fa380
daa0358
b9daf1b
daa0358
 
 
 
 
88c89bd
 
daa0358
88c89bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0358
8a685c0
daa0358
88c89bd
 
 
 
 
 
 
daa0358
88c89bd
 
 
daa0358
 
7c3bc96
 
daa0358
 
 
 
 
 
 
47fa380
daa0358
8a685c0
 
daa0358
 
 
 
 
47fa380
 
daa0358
47fa380
 
 
 
 
daa0358
47fa380
 
 
 
 
 
 
 
daa0358
 
8a685c0
 
daa0358
 
 
7c3bc96
88c89bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9daf1b
7c3bc96
 
b9daf1b
 
88c89bd
7c3bc96
daa0358
7c3bc96
 
88c89bd
47fa380
 
 
 
88c89bd
daa0358
47fa380
 
 
8a685c0
 
daa0358
 
 
7c3bc96
 
 
88c89bd
 
 
 
8a685c0
7c3bc96
daa0358
7c3bc96
daa0358
b9daf1b
88c89bd
 
 
 
b9daf1b
 
88c89bd
 
 
 
b9daf1b
7c3bc96
 
b9daf1b
8a685c0
 
daa0358
 
 
 
7c3bc96
daa0358
 
88c89bd
 
 
b9daf1b
7c3bc96
daa0358
b9daf1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0358
b9daf1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
# grpo_train.py

import os
import time
import json
import random
import requests
import torch

from datasets import Dataset
from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOTrainer, GRPOConfig

# Patch TRL's GRPO trainer with Unsloth's fast RL implementation. Done at import
# time, presumably so it takes effect before GRPOTrainer is instantiated below.
PatchFastRL("GRPO", FastLanguageModel)

# #region agent log
import pathlib as _pl
_DLOG = _pl.Path("debug-851b5f.log")
def _dlog(hyp, loc, msg, data=None):
    import time as _t
    entry = json.dumps({"sessionId":"851b5f","hypothesisId":hyp,"location":loc,"message":msg,"data":data or {},"timestamp":int(_t.time()*1000)})
    with open(_DLOG, "a") as f: f.write(entry + "\n")
    print(f"[DBG:{hyp}] {msg} {data or ''}", flush=True)
# #endregion

# =========================
# CONFIG
# =========================

# Base URL of the ad-review environment HTTP server (override via ENV_URL).
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
# Hugging Face token passed when pushing the merged model; "" if unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Target Hub repo for the merged model; when empty the push step is skipped.
HF_REPO = os.getenv("HF_REPO", "")  # e.g. "yourname/metaguard-llama3.1-8b-grpo"

# Closed set of action_type values the policy may emit. reward_environment
# scores any other value -1.0 without contacting the environment. Keep in sync
# with the action list written out in PROMPT_TEMPLATE.
ALLOWED_ACTIONS = [
    "query_regulations",
    "analyze_image",
    "check_advertiser_history",
    "request_landing_page",
    "request_id_verification",
    "submit_audit",
    "approve",
    "reject",
]

# =========================
# HEALTH CHECK
# =========================

def ensure_env_ready(attempts=20, delay=1.0):
    """Wait until the environment server accepts ``POST /reset``.

    Sends ``{"task_id": "task_1_healthcare"}`` to ``{ENV_URL}/reset`` up to
    *attempts* times (5 s timeout each), sleeping *delay* seconds between
    probes, and returns as soon as one probe gets HTTP 200.

    Args:
        attempts: Maximum number of probes (default 20).
        delay: Seconds to sleep between consecutive probes (default 1.0).

    Raises:
        RuntimeError: if no probe returned HTTP 200. The message names the
            URL and the last failure seen (HTTP status or request error), so
            the cause is visible without reading the debug log.
    """
    # #region agent log
    _dlog("B", "grpo_train.py:ensure_env_ready", "Checking env", {"ENV_URL": ENV_URL})
    # #endregion
    last_failure = "no probe attempted"
    for i in range(attempts):
        try:
            r = requests.post(
                f"{ENV_URL}/reset",
                json={"task_id": "task_1_healthcare"},
                timeout=5,
            )
        except requests.RequestException as e:
            # Connection refused / timeout, e.g. while the server is starting.
            last_failure = f"{type(e).__name__}: {e}"
            # #region agent log
            if i == 0:
                _dlog("B", "grpo_train.py:ensure_env_ready", "Env connection failed", {"attempt": i+1, "error": str(e)[:200]})
            # #endregion
        else:
            if r.status_code == 200:
                # #region agent log
                _dlog("B", "grpo_train.py:ensure_env_ready", "Env ready", {"attempt": i+1, "status": r.status_code})
                # #endregion
                print("✅ Environment ready")
                return
            # Server answered but not with 200: remember why, then retry.
            last_failure = f"HTTP {r.status_code}"
        if i < attempts - 1:  # no point sleeping after the final probe
            time.sleep(delay)
    # #region agent log
    _dlog("B", "grpo_train.py:ensure_env_ready", f"ENV UNREACHABLE after {attempts} attempts", {"last_failure": last_failure})
    # #endregion
    raise RuntimeError(f"❌ ENV not reachable at {ENV_URL} ({last_failure})")

# =========================
# SAFE CLIENT
# =========================

class EnvClient:
    """Minimal JSON-over-HTTP client for the environment server.

    Each call issues one ``requests.post`` and returns the decoded JSON body.
    HTTP error statuses are not raised: the caller receives whatever JSON the
    server sent. Network errors and non-JSON bodies propagate as exceptions
    from ``requests``.
    """

    def __init__(self, url, timeout=8):
        """
        Args:
            url: Base URL of the server, e.g. "http://localhost:8000".
            timeout: Per-request timeout in seconds (default 8).
        """
        self.url = url
        self.timeout = timeout

    def _post(self, path, payload):
        """POST *payload* as JSON to ``{url}{path}`` and return the decoded body."""
        return requests.post(
            f"{self.url}{path}",
            json=payload,
            timeout=self.timeout,
        ).json()

    def reset(self, task_id):
        """POST ``{"task_id": task_id}`` to ``/reset``; returns the JSON reply."""
        return self._post("/reset", {"task_id": task_id})

    def step(self, action):
        """POST ``{"action": action}`` to ``/step``; returns the JSON reply.

        *action* is expected to be a dict with ``action_type`` and
        ``reasoning`` keys (as built by the callers in this file).
        """
        return self._post("/step", {"action": action})

def safe_step(client, action, retries=3, delay=0.5):
    """Call ``client.step(action)`` with retries, never raising on ordinary errors.

    Args:
        client: Object exposing ``step(action)`` (e.g. ``EnvClient``).
        action: Action dict forwarded unchanged to ``client.step``.
        retries: Maximum number of attempts (default 3).
        delay: Seconds to sleep between failed attempts (default 0.5).

    Returns:
        The result of the first successful ``client.step`` call, or the
        fallback ``{"reward": -0.3}`` if every attempt raised.
    """
    for attempt in range(retries):
        try:
            return client.step(action)
        except Exception:
            # Catch Exception rather than BaseException so Ctrl-C and
            # SystemExit abort the run instead of being retried.
            if attempt < retries - 1:  # don't sleep after the last failure
                time.sleep(delay)
    return {"reward": -0.3}

# =========================
# JSON PARSER
# =========================

def extract_json(text):
    """Parse a model completion into a JSON object.

    If *text* contains a Markdown code fence, only the content of the first
    fenced section is parsed, after removing an optional leading ``json``
    language tag.

    Args:
        text: Raw completion string.

    Returns:
        The decoded ``dict``, or ``None`` when *text* is not a string, is not
        valid JSON, or decodes to a non-object value (list, number, string,
        ``null``). Callers index the result with ``.get``, so only dicts are
        returned.
    """
    if not isinstance(text, str):
        return None
    if "```" in text:
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
    try:
        parsed = json.loads(text.strip())
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None

# =========================
# DATASET (WITH SETUP ACTIONS)
# =========================

# Training scenarios; each entry is one decision point for the policy:
#   task_id               - environment task passed to /reset
#   text                  - scenario description inserted into the prompt
#   actions_already_taken - action names shown to the model in the prompt
#   setup_actions         - action dicts replayed via /step after the reset so
#                           the environment state matches actions_already_taken
#                           (keep the two lists in sync)
BASE_SCENARIOS = [
    # Phase 1 - Fresh state, expected: query_regulations
    {
        "task_id": "task_1_healthcare",
        "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
        "actions_already_taken": [],
        "setup_actions": [],
    },
    {
        "task_id": "task_2_financial",
        "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
        "actions_already_taken": [],
        "setup_actions": [],
    },
    {
        "task_id": "task_3_multimodal",
        "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
        "actions_already_taken": [],
        "setup_actions": [],
    },

    # Phase 2 - Policy checked, expected: analyze_image OR check_advertiser_history
    {
        "task_id": "task_1_healthcare",
        "text": "Healthcare ad: pharma product. Policy already queried.",
        "actions_already_taken": ["query_regulations"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
        ],
    },
    {
        "task_id": "task_3_multimodal",
        "text": "Multimodal ad: image not yet inspected. Policy already queried.",
        "actions_already_taken": ["query_regulations"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
        ],
    },

    # Phase 3 - Policy + history checked, expected: submit_audit
    {
        "task_id": "task_2_financial",
        "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
        "actions_already_taken": ["query_regulations", "check_advertiser_history"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
            {"action_type": "check_advertiser_history", "reasoning": "trust score"},
        ],
    },

    # Phase 4 - Audit complete, expected: reject (high-risk) or approve (clean)
    {
        "task_id": "task_2_financial",
        "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
        "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
            {"action_type": "check_advertiser_history", "reasoning": "trust score"},
            {"action_type": "submit_audit", "reasoning": "audit log"},
        ],
    },

    # Targeting task - fresh state, expected: query_regulations
    {
        "task_id": "task_4_targeting",
        "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
        "actions_already_taken": [],
        "setup_actions": [],
    },
    # Targeting task - mid state, expected: request_id_verification (age check)
    {
        "task_id": "task_4_targeting",
        "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
        "actions_already_taken": ["query_regulations"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
        ],
    },
    # Targeting task - audit ready, expected: submit_audit
    {
        "task_id": "task_4_targeting",
        "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
        "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
        "setup_actions": [
            {"action_type": "query_regulations", "reasoning": "policy lookup"},
            {"action_type": "check_advertiser_history", "reasoning": "trust score"},
            {"action_type": "request_id_verification", "reasoning": "age check"},
        ],
    },
]

# Prompt sent to the policy for every dataset row. build_dataset fills it with
# str.format(): {text} is the scenario description and {actions_already_taken}
# the JSON-encoded list of action names. The braces of the example JSON object
# are doubled ({{ }}) so str.format emits them literally. The action list below
# mirrors ALLOWED_ACTIONS and must be kept in sync with it.
PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.

You MUST choose exactly ONE action_type from this list (any other value is invalid):
- query_regulations
- analyze_image
- check_advertiser_history
- request_landing_page
- request_id_verification
- submit_audit
- approve
- reject

REQUIRED PHASE ORDER:
1. query_regulations  -> always first
2. analyze_image / check_advertiser_history  -> gather signals
3. submit_audit  -> always before final decision
4. approve OR reject  -> only after audit

HARD RULES:
- NEVER repeat an action listed in `actions_already_taken`.
- Respond with ONLY a valid JSON object. No markdown, no prose.

Required format:
{{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}

Scenario: {text}
actions_already_taken: {actions_already_taken}

Your next action?"""


def build_dataset(repeats=10):
    """Build the GRPO training dataset from ``BASE_SCENARIOS``.

    Each scenario becomes one row with a formatted ``prompt`` plus the
    ``task_id`` and ``setup_actions`` columns, which reward_environment
    receives through its ``task_id`` / ``setup_actions`` arguments. The row
    list is repeated *repeats* times, giving ``len(BASE_SCENARIOS) * repeats``
    rows (10 scenarios x 10 = 100 with the current data and default).

    Args:
        repeats: Number of copies of the scenario list (default 10).

    Raises:
        ValueError: if a scenario's ``actions_already_taken`` (shown to the
            model) differs from the action types in its ``setup_actions``
            (replayed in the environment before scoring).
    """
    rows = []
    for s in BASE_SCENARIOS:
        # The prompt tells the model which actions were taken, but
        # setup_actions is what the reward function replays. If the two
        # diverge, completions are scored against a state the prompt does
        # not describe, so fail fast at startup.
        replayed = [a["action_type"] for a in s["setup_actions"]]
        if replayed != list(s["actions_already_taken"]):
            raise ValueError(
                f"Scenario for {s['task_id']!r} is inconsistent: "
                f"actions_already_taken={s['actions_already_taken']!r} but "
                f"setup_actions replay {replayed!r}"
            )
        prompt = PROMPT_TEMPLATE.format(
            text=s["text"],
            actions_already_taken=json.dumps(s["actions_already_taken"]),
        )
        rows.append({
            "prompt": prompt,
            "task_id": s["task_id"],
            "setup_actions": s["setup_actions"],
        })
    return Dataset.from_list(rows * repeats)

# =========================
# REWARD FUNCTION (FIXED)
# =========================

_reward_call_count = [0]  # one-element list so the function can bump it without `global`


def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
    """Shaped GRPO reward, scored by replaying each completion in the environment.

    Per completion:
      * unparseable output, JSON that is not an object, or an ``action_type``
        outside ``ALLOWED_ACTIONS`` -> -1.0 (no environment call);
      * otherwise the environment is reset to the row's ``task_id``, the row's
        ``setup_actions`` are replayed, and the parsed action is stepped:
          - status message containing "api failure", "invalid action" or
            "must call" (case-insensitive) -> -0.5;
          - otherwise 0.5 + the environment's ``reward`` (-0.2 if absent);
      * any exception while talking to the environment -> -0.3.

    Args:
        prompts: Batch prompts (only used for debug logging).
        completions: Model completions, one per prompt; expected to be plain
            strings since the dataset prompts are plain strings.
        task_id: Per-completion task ids (dataset column), or None.
        setup_actions: Per-completion lists of action dicts to replay before
            scoring (dataset column), or None.
        **kwargs: Other dataset columns / trainer extras; ignored.

    Returns:
        list[float]: one reward per completion. If ``task_id`` or
        ``setup_actions`` is missing, every completion gets -1.0.
    """
    _reward_call_count[0] += 1
    _call = _reward_call_count[0]
    # #region agent log
    _dlog("C", "grpo_train.py:reward_env", f"reward call #{_call}", {
        "n_prompts": len(prompts) if prompts else 0,
        "n_completions": len(completions) if completions else 0,
        "completions_type": type(completions).__name__,
        "first_completion_type": type(completions[0]).__name__ if completions else "N/A",
        "first_completion_preview": str(completions[0])[:150] if completions else "N/A",
        "task_id_is_none": task_id is None,
        "setup_actions_is_none": setup_actions is None,
        "kwargs_keys": list(kwargs.keys()),
    })
    # #endregion

    if task_id is None or setup_actions is None:
        # #region agent log
        _dlog("D", "grpo_train.py:reward_env", "task_id or setup_actions is None — returning -1 for all", {"call": _call})
        # #endregion
        return [-1.0] * len(completions)

    client = EnvClient(ENV_URL)
    rewards = []

    for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
        parsed = extract_json(completion)
        if not isinstance(parsed, dict):
            # JSON that is not an object ("42", "[...]", '"approve"') has no
            # .get(); score it like unparseable output instead of letting an
            # AttributeError abort the training step.
            parsed = None
        # #region agent log
        if _call <= 3: _dlog("D", "grpo_train.py:reward_loop", f"call#{_call} item#{idx}", {"parsed_ok": parsed is not None, "action": parsed.get("action_type") if parsed else None, "raw_preview": str(completion)[:120], "task_id": t_id})
        # #endregion
        if not parsed:
            rewards.append(-1.0)
            continue

        action_type = parsed.get("action_type")
        if action_type not in ALLOWED_ACTIONS:
            rewards.append(-1.0)
            continue

        action = {
            "action_type": action_type,
            "reasoning": parsed.get("reasoning", "format-compliant"),
        }

        try:
            # Rebuild the scenario's state: fresh episode, then replay the
            # actions the prompt says were already taken.
            client.reset(t_id)
            for s in setup:
                safe_step(client, s)

            result = safe_step(client, action)
            env_reward = float(result.get("reward", -0.2))
            status_msg = (result.get("status_message") or "").lower()

            # The environment reports rejected actions through its status
            # message; these get a fixed penalty regardless of env reward.
            rejected = (
                "api failure" in status_msg
                or "invalid action" in status_msg
                or "must call" in status_msg
            )

            # NOTE(review): when safe_step exhausts its retries it returns
            # {"reward": -0.3} with no status_message, which scores
            # 0.5 - 0.3 = 0.2 here (above a rejected action). Confirm intended.
            if rejected:
                shaped = -0.5
            else:
                shaped = 0.5 + env_reward

            rewards.append(shaped)

        except Exception:
            # Environment/network/decoding failures must not abort training.
            rewards.append(-0.3)

    return rewards

# =========================
# MODEL
# =========================

if not torch.cuda.is_available():
    # No GPU: zero VRAM forces the 4-bit path below.
    _vram = 0
    _name = "CPU"
    _cc = (0, 0)
else:
    _props = torch.cuda.get_device_properties(0)
    _name = _props.name
    _vram = _props.total_memory  # bytes
    _cc = (_props.major, _props.minor)  # CUDA compute capability (major, minor)
    print(f"GPU: {_name}  VRAM: {_vram / 1024**3:.1f} GB  Compute: {_cc[0]}.{_cc[1]}")

# Quantize to 4-bit below 40 GiB of VRAM: T4 (15 GB), L4 (24 GB) -> 4-bit;
# A100 (80 GB) -> full precision.
USE_4BIT = _vram < (40 << 30)
# bf16 only on the full-precision path and only for compute capability >= 8.0;
# the 4-bit path loads the model in fp16.
USE_BF16 = not USE_4BIT and _cc >= (8, 0)

# #region agent log
# Debug (hypothesis "A"): record the detected device and resolved precision flags.
_dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})
# #endregion

# Load the base model through Unsloth. On the 4-bit path the weights are
# quantized and dtype is forced to fp16; otherwise dtype=None (presumably
# letting Unsloth choose bf16/fp16 for the GPU - confirm against Unsloth docs).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    load_in_4bit=USE_4BIT,
    max_seq_length=2048,
    dtype=torch.float16 if USE_4BIT else None,
)

# Wrap the model with LoRA adapters on the attention (q/k/v/o) and MLP
# (gate/up/down) projections. The 4-bit path uses a smaller rank (16 vs 32);
# lora_alpha is 2x the rank on both paths.
model = FastLanguageModel.get_peft_model(
    model,
    r=16 if USE_4BIT else 32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=32 if USE_4BIT else 64,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's checkpointing variant (trades compute for activation memory)
    random_state=3407,  # fixed seed, presumably for adapter initialisation
)

# =========================
# TRAINER
# =========================

dataset = build_dataset()

# Hyper-parameters depend on the precision path: the 4-bit (smaller GPU) path
# uses one epoch, a smaller batch and fewer generations per prompt.
training_args = GRPOConfig(
    output_dir="outputs",
    learning_rate=2e-5,
    num_train_epochs=1 if USE_4BIT else 3,
    per_device_train_batch_size=1 if USE_4BIT else 2,
    gradient_accumulation_steps=2 if USE_4BIT else 4,
    num_generations=2 if USE_4BIT else 4,  # GRPO group size (completions per prompt)
    max_prompt_length=768,
    max_completion_length=128,
    logging_steps=3 if USE_4BIT else 5,
    warmup_steps=5 if USE_4BIT else 10,
    bf16=USE_BF16,
    fp16=not USE_BF16,
    report_to="none",
)

# #region agent log
# Log the values actually held by the config object rather than re-typing the
# literals above, so the debug log cannot drift from the real settings.
_dlog("A", "grpo_train.py:trainer_init", "Creating GRPOTrainer", {"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16,"epochs":training_args.num_train_epochs,"batch":training_args.per_device_train_batch_size,"gens":training_args.num_generations,"dataset_len":len(dataset)})
# #endregion

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_environment],
    args=training_args,
    train_dataset=dataset,
    # TRL's GRPOTrainer takes the tokenizer as `processing_class`; it has no
    # `tokenizer` keyword argument.
    processing_class=tokenizer,
)

# =========================
# RUN
# =========================

if __name__ == "__main__":
    # Entry point: check the environment server, train, save the LoRA adapter,
    # merge it into a 16-bit model, and optionally push that to the Hub.
    ensure_env_ready()

    # #region agent log
    _dlog("E", "grpo_train.py:train_start", "About to call trainer.train()", {"gpu_mem_allocated_gb": round(torch.cuda.memory_allocated()/1024**3, 2) if torch.cuda.is_available() else 0})
    # #endregion
    print("Starting GRPO training...")
    trainer.train()

    model.save_pretrained("outputs/lora_adapter")
    tokenizer.save_pretrained("outputs/lora_adapter")
    print("LoRA adapter saved to outputs/lora_adapter")

    # Release the training copy of the model (and the trainer holding it and
    # its optimizer state) before loading a second, unquantized copy of the
    # 8B model for merging; otherwise both compete for the same GPU memory.
    import gc

    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("Merging adapter into base model (bf16)...")
    merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
        model_name="outputs/lora_adapter",
        load_in_4bit=False,
        max_seq_length=2048,
    )
    merged_model.save_pretrained_merged(
        "outputs/merged",
        merged_tokenizer,
        save_method="merged_16bit",
    )
    print("Merged model saved to outputs/merged")

    if HF_REPO:
        print(f"Pushing merged model to {HF_REPO}...")
        merged_model.push_to_hub_merged(
            HF_REPO,
            merged_tokenizer,
            save_method="merged_16bit",
            # An empty token string would be sent as-is; None lets the Hub
            # client fall back to cached login credentials.
            token=HF_TOKEN or None,
        )
        print(f"Model live at https://huggingface.co/{HF_REPO}")
    else:
        print("Set HF_REPO env var to auto-push to Hub (skipped).")

    print("Done.")