parth-1 commited on
Commit
193a9d2
·
verified ·
1 Parent(s): fbc6056

fixed fp16 bfp16 clash

Browse files
Files changed (1) hide show
  1. grpo_train.py +392 -391
grpo_train.py CHANGED
@@ -1,392 +1,393 @@
1
- # grpo_train.py
2
-
3
- import os
4
- import time
5
- import json
6
- import random
7
- import requests
8
- import torch
9
-
10
- from datasets import Dataset
11
- from unsloth import FastLanguageModel, PatchFastRL
12
- from trl import GRPOTrainer, GRPOConfig
13
-
14
- PatchFastRL("GRPO", FastLanguageModel)
15
-
16
- # =========================
17
- # CONFIG
18
- # =========================
19
-
20
- ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
21
- HF_TOKEN = os.getenv("HF_TOKEN", "")
22
- HF_REPO = os.getenv("HF_REPO", "") # e.g. "yourname/metaguard-llama3.1-8b-grpo"
23
-
24
- ALLOWED_ACTIONS = [
25
- "query_regulations",
26
- "analyze_image",
27
- "check_advertiser_history",
28
- "request_landing_page",
29
- "request_id_verification",
30
- "submit_audit",
31
- "approve",
32
- "reject",
33
- ]
34
-
35
- # =========================
36
- # HEALTH CHECK
37
- # =========================
38
-
39
- def ensure_env_ready():
40
- for _ in range(20):
41
- try:
42
- r = requests.post(
43
- f"{ENV_URL}/reset",
44
- json={"task_id": "task_1_healthcare"},
45
- timeout=5
46
- )
47
- if r.status_code == 200:
48
- print("✅ Environment ready")
49
- return
50
- except:
51
- pass
52
- time.sleep(1)
53
- raise RuntimeError("❌ ENV not reachable")
54
-
55
- # =========================
56
- # SAFE CLIENT
57
- # =========================
58
-
59
- class EnvClient:
60
- def __init__(self, url):
61
- self.url = url
62
-
63
- def reset(self, task_id):
64
- return requests.post(
65
- f"{self.url}/reset",
66
- json={"task_id": task_id},
67
- timeout=8
68
- ).json()
69
-
70
- def step(self, action):
71
- return requests.post(
72
- f"{self.url}/step",
73
- json={"action": action},
74
- timeout=8
75
- ).json()
76
-
77
- def safe_step(client, action):
78
- for _ in range(3):
79
- try:
80
- return client.step(action)
81
- except:
82
- time.sleep(0.5)
83
- return {"reward": -0.3}
84
-
85
- # =========================
86
- # JSON PARSER
87
- # =========================
88
-
89
- def extract_json(text):
90
- try:
91
- if "```" in text:
92
- text = text.split("```")[1]
93
- if text.startswith("json"):
94
- text = text[4:]
95
- return json.loads(text.strip())
96
- except:
97
- return None
98
-
99
- # =========================
100
- # DATASET (WITH SETUP ACTIONS)
101
- # =========================
102
-
103
- BASE_SCENARIOS = [
104
- # Phase 1 — Fresh state, expected: query_regulations
105
- {
106
- "task_id": "task_1_healthcare",
107
- "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
108
- "actions_already_taken": [],
109
- "setup_actions": [],
110
- },
111
- {
112
- "task_id": "task_2_financial",
113
- "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
114
- "actions_already_taken": [],
115
- "setup_actions": [],
116
- },
117
- {
118
- "task_id": "task_3_multimodal",
119
- "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
120
- "actions_already_taken": [],
121
- "setup_actions": [],
122
- },
123
-
124
- # Phase 2 — Policy checked, expected: analyze_image OR check_advertiser_history
125
- {
126
- "task_id": "task_1_healthcare",
127
- "text": "Healthcare ad: pharma product. Policy already queried.",
128
- "actions_already_taken": ["query_regulations"],
129
- "setup_actions": [
130
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
131
- ],
132
- },
133
- {
134
- "task_id": "task_3_multimodal",
135
- "text": "Multimodal ad: image not yet inspected. Policy already queried.",
136
- "actions_already_taken": ["query_regulations"],
137
- "setup_actions": [
138
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
139
- ],
140
- },
141
-
142
- # Phase 3 — Policy + history checked, expected: submit_audit
143
- {
144
- "task_id": "task_2_financial",
145
- "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
146
- "actions_already_taken": ["query_regulations", "check_advertiser_history"],
147
- "setup_actions": [
148
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
149
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
150
- ],
151
- },
152
-
153
- # Phase 4 — Audit complete, expected: reject (high-risk) or approve (clean)
154
- {
155
- "task_id": "task_2_financial",
156
- "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
157
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
158
- "setup_actions": [
159
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
160
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
161
- {"action_type": "submit_audit", "reasoning": "audit log"},
162
- ],
163
- },
164
-
165
- # Targeting task — fresh state, expected: query_regulations
166
- {
167
- "task_id": "task_4_targeting",
168
- "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
169
- "actions_already_taken": [],
170
- "setup_actions": [],
171
- },
172
- # Targeting task — mid state, expected: request_id_verification (age check)
173
- {
174
- "task_id": "task_4_targeting",
175
- "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
176
- "actions_already_taken": ["query_regulations"],
177
- "setup_actions": [
178
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
179
- ],
180
- },
181
- # Targeting task — audit ready
182
- {
183
- "task_id": "task_4_targeting",
184
- "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
185
- "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
186
- "setup_actions": [
187
- {"action_type": "query_regulations", "reasoning": "policy lookup"},
188
- {"action_type": "check_advertiser_history", "reasoning": "trust score"},
189
- {"action_type": "request_id_verification", "reasoning": "age check"},
190
- ],
191
- },
192
- ]
193
-
194
- PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
195
-
196
- You MUST choose exactly ONE action_type from this list (any other value is invalid):
197
- - query_regulations
198
- - analyze_image
199
- - check_advertiser_history
200
- - request_landing_page
201
- - request_id_verification
202
- - submit_audit
203
- - approve
204
- - reject
205
-
206
- REQUIRED PHASE ORDER:
207
- 1. query_regulations -> always first
208
- 2. analyze_image / check_advertiser_history -> gather signals
209
- 3. submit_audit -> always before final decision
210
- 4. approve OR reject -> only after audit
211
-
212
- HARD RULES:
213
- - NEVER repeat an action listed in `actions_already_taken`.
214
- - Respond with ONLY a valid JSON object. No markdown, no prose.
215
-
216
- Required format:
217
- {{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}
218
-
219
- Scenario: {text}
220
- actions_already_taken: {actions_already_taken}
221
-
222
- Your next action?"""
223
-
224
-
225
- def build_dataset():
226
- rows = []
227
- for s in BASE_SCENARIOS:
228
- prompt = PROMPT_TEMPLATE.format(
229
- text=s["text"],
230
- actions_already_taken=json.dumps(s["actions_already_taken"]),
231
- )
232
- rows.append({
233
- "prompt": prompt,
234
- "task_id": s["task_id"],
235
- "setup_actions": s["setup_actions"],
236
- })
237
- return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
238
-
239
- # =========================
240
- # REWARD FUNCTION (FIXED)
241
- # =========================
242
-
243
- def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
244
- """Shaped reward for GRPO.
245
-
246
- Pure env reward is too sparse (mostly -0.05) to give clear gradients.
247
- We add explicit shaping:
248
- - invalid JSON / invalid action_type -> -1.0 (strong negative signal)
249
- - valid action env REJECTS (wrong phase / API failure) -> -0.5
250
- - valid action env ACCEPTS (advances state) -> +0.5 + env_reward
251
- - terminal correct decision -> env_reward already contains +1.0 bonus
252
- """
253
- client = EnvClient(ENV_URL)
254
- rewards = []
255
-
256
- for completion, t_id, setup in zip(completions, task_id, setup_actions):
257
- parsed = extract_json(completion)
258
- if not parsed:
259
- rewards.append(-1.0)
260
- continue
261
-
262
- action_type = parsed.get("action_type")
263
- if action_type not in ALLOWED_ACTIONS:
264
- rewards.append(-1.0)
265
- continue
266
-
267
- action = {
268
- "action_type": action_type,
269
- "reasoning": parsed.get("reasoning", "format-compliant"),
270
- }
271
-
272
- try:
273
- client.reset(t_id)
274
- for s in setup:
275
- safe_step(client, s)
276
-
277
- result = safe_step(client, action)
278
- env_reward = float(result.get("reward", -0.2))
279
- status_msg = (result.get("status_message") or "").lower()
280
-
281
- rejected = (
282
- "api failure" in status_msg
283
- or "invalid action" in status_msg
284
- or "must call" in status_msg
285
- )
286
-
287
- if rejected:
288
- shaped = -0.5
289
- else:
290
- shaped = 0.5 + env_reward
291
-
292
- rewards.append(shaped)
293
-
294
- except Exception:
295
- rewards.append(-0.3)
296
-
297
- return rewards
298
-
299
- # =========================
300
- # MODEL
301
- # =========================
302
-
303
- USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3
304
-
305
- model, tokenizer = FastLanguageModel.from_pretrained(
306
- model_name="unsloth/Llama-3.1-8B-Instruct",
307
- load_in_4bit=USE_4BIT,
308
- max_seq_length=2048,
309
- dtype=None, # auto-detect bf16 on A100
310
- )
311
-
312
- model = FastLanguageModel.get_peft_model(
313
- model,
314
- r=32,
315
- target_modules=[
316
- "q_proj", "k_proj", "v_proj", "o_proj",
317
- "gate_proj", "up_proj", "down_proj",
318
- ],
319
- lora_alpha=64,
320
- lora_dropout=0,
321
- bias="none",
322
- use_gradient_checkpointing="unsloth",
323
- random_state=3407,
324
- )
325
-
326
- # =========================
327
- # TRAINER
328
- # =========================
329
-
330
- dataset = build_dataset()
331
-
332
- trainer = GRPOTrainer(
333
- model=model,
334
- reward_funcs=[reward_environment],
335
- args=GRPOConfig(
336
- output_dir="outputs",
337
- learning_rate=2e-5,
338
- num_train_epochs=3,
339
- per_device_train_batch_size=2,
340
- gradient_accumulation_steps=4,
341
- num_generations=4,
342
- max_prompt_length=768,
343
- max_completion_length=128,
344
- logging_steps=5,
345
- warmup_ratio=0.1,
346
- bf16=True,
347
- report_to="none",
348
- ),
349
- train_dataset=dataset,
350
- tokenizer=tokenizer,
351
- )
352
-
353
- # =========================
354
- # RUN
355
- # =========================
356
-
357
- if __name__ == "__main__":
358
- ensure_env_ready()
359
-
360
- print("Starting GRPO training...")
361
- trainer.train()
362
-
363
- model.save_pretrained("outputs/lora_adapter")
364
- tokenizer.save_pretrained("outputs/lora_adapter")
365
- print("LoRA adapter saved to outputs/lora_adapter")
366
-
367
- print("Merging adapter into base model (bf16)...")
368
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
369
- model_name="outputs/lora_adapter",
370
- load_in_4bit=False,
371
- max_seq_length=2048,
372
- )
373
- merged_model.save_pretrained_merged(
374
- "outputs/merged",
375
- merged_tokenizer,
376
- save_method="merged_16bit",
377
- )
378
- print("Merged model saved to outputs/merged")
379
-
380
- if HF_REPO:
381
- print(f"Pushing merged model to {HF_REPO}...")
382
- merged_model.push_to_hub_merged(
383
- HF_REPO,
384
- merged_tokenizer,
385
- save_method="merged_16bit",
386
- token=HF_TOKEN,
387
- )
388
- print(f"Model live at https://huggingface.co/{HF_REPO}")
389
- else:
390
- print("Set HF_REPO env var to auto-push to Hub (skipped).")
391
-
 
392
  print("Done.")
 
1
+ # grpo_train.py
2
+
3
+ import os
4
+ import time
5
+ import json
6
+ import random
7
+ import requests
8
+ import torch
9
+
10
+ from datasets import Dataset
11
+ from unsloth import FastLanguageModel, PatchFastRL
12
+ from trl import GRPOTrainer, GRPOConfig
13
+
14
+ PatchFastRL("GRPO", FastLanguageModel)
15
+
16
+ # =========================
17
+ # CONFIG
18
+ # =========================
19
+
20
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
21
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
22
+ HF_REPO = os.getenv("HF_REPO", "") # e.g. "yourname/metaguard-llama3.1-8b-grpo"
23
+
24
+ ALLOWED_ACTIONS = [
25
+ "query_regulations",
26
+ "analyze_image",
27
+ "check_advertiser_history",
28
+ "request_landing_page",
29
+ "request_id_verification",
30
+ "submit_audit",
31
+ "approve",
32
+ "reject",
33
+ ]
34
+
35
+ # =========================
36
+ # HEALTH CHECK
37
+ # =========================
38
+
39
+ def ensure_env_ready():
40
+ for _ in range(20):
41
+ try:
42
+ r = requests.post(
43
+ f"{ENV_URL}/reset",
44
+ json={"task_id": "task_1_healthcare"},
45
+ timeout=5
46
+ )
47
+ if r.status_code == 200:
48
+ print("✅ Environment ready")
49
+ return
50
+ except:
51
+ pass
52
+ time.sleep(1)
53
+ raise RuntimeError("❌ ENV not reachable")
54
+
55
+ # =========================
56
+ # SAFE CLIENT
57
+ # =========================
58
+
59
+ class EnvClient:
60
+ def __init__(self, url):
61
+ self.url = url
62
+
63
+ def reset(self, task_id):
64
+ return requests.post(
65
+ f"{self.url}/reset",
66
+ json={"task_id": task_id},
67
+ timeout=8
68
+ ).json()
69
+
70
+ def step(self, action):
71
+ return requests.post(
72
+ f"{self.url}/step",
73
+ json={"action": action},
74
+ timeout=8
75
+ ).json()
76
+
77
+ def safe_step(client, action):
78
+ for _ in range(3):
79
+ try:
80
+ return client.step(action)
81
+ except:
82
+ time.sleep(0.5)
83
+ return {"reward": -0.3}
84
+
85
+ # =========================
86
+ # JSON PARSER
87
+ # =========================
88
+
89
+ def extract_json(text):
90
+ try:
91
+ if "```" in text:
92
+ text = text.split("```")[1]
93
+ if text.startswith("json"):
94
+ text = text[4:]
95
+ return json.loads(text.strip())
96
+ except:
97
+ return None
98
+
99
+ # =========================
100
+ # DATASET (WITH SETUP ACTIONS)
101
+ # =========================
102
+
103
+ BASE_SCENARIOS = [
104
+ # Phase 1 — Fresh state, expected: query_regulations
105
+ {
106
+ "task_id": "task_1_healthcare",
107
+ "text": "Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.",
108
+ "actions_already_taken": [],
109
+ "setup_actions": [],
110
+ },
111
+ {
112
+ "task_id": "task_2_financial",
113
+ "text": "Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.",
114
+ "actions_already_taken": [],
115
+ "setup_actions": [],
116
+ },
117
+ {
118
+ "task_id": "task_3_multimodal",
119
+ "text": "Multimodal ad: image may contain hidden violation. No actions taken yet.",
120
+ "actions_already_taken": [],
121
+ "setup_actions": [],
122
+ },
123
+
124
+ # Phase 2 — Policy checked, expected: analyze_image OR check_advertiser_history
125
+ {
126
+ "task_id": "task_1_healthcare",
127
+ "text": "Healthcare ad: pharma product. Policy already queried.",
128
+ "actions_already_taken": ["query_regulations"],
129
+ "setup_actions": [
130
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
131
+ ],
132
+ },
133
+ {
134
+ "task_id": "task_3_multimodal",
135
+ "text": "Multimodal ad: image not yet inspected. Policy already queried.",
136
+ "actions_already_taken": ["query_regulations"],
137
+ "setup_actions": [
138
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
139
+ ],
140
+ },
141
+
142
+ # Phase 3 — Policy + history checked, expected: submit_audit
143
+ {
144
+ "task_id": "task_2_financial",
145
+ "text": "Financial ad: investment scheme. Policy and advertiser history both checked.",
146
+ "actions_already_taken": ["query_regulations", "check_advertiser_history"],
147
+ "setup_actions": [
148
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
149
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
150
+ ],
151
+ },
152
+
153
+ # Phase 4 — Audit complete, expected: reject (high-risk) or approve (clean)
154
+ {
155
+ "task_id": "task_2_financial",
156
+ "text": "Financial ad: investment scheme. Policy, history, and audit all complete. Make final decision.",
157
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "submit_audit"],
158
+ "setup_actions": [
159
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
160
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
161
+ {"action_type": "submit_audit", "reasoning": "audit log"},
162
+ ],
163
+ },
164
+
165
+ # Targeting task — fresh state, expected: query_regulations
166
+ {
167
+ "task_id": "task_4_targeting",
168
+ "text": "Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.",
169
+ "actions_already_taken": [],
170
+ "setup_actions": [],
171
+ },
172
+ # Targeting task — mid state, expected: request_id_verification (age check)
173
+ {
174
+ "task_id": "task_4_targeting",
175
+ "text": "Financial ad targeting young users. Policy queried, need to verify age targeting.",
176
+ "actions_already_taken": ["query_regulations"],
177
+ "setup_actions": [
178
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
179
+ ],
180
+ },
181
+ # Targeting task — audit ready
182
+ {
183
+ "task_id": "task_4_targeting",
184
+ "text": "Financial ad targeting minors. Policy, advertiser history, and ID verification done. Submit audit.",
185
+ "actions_already_taken": ["query_regulations", "check_advertiser_history", "request_id_verification"],
186
+ "setup_actions": [
187
+ {"action_type": "query_regulations", "reasoning": "policy lookup"},
188
+ {"action_type": "check_advertiser_history", "reasoning": "trust score"},
189
+ {"action_type": "request_id_verification", "reasoning": "age check"},
190
+ ],
191
+ },
192
+ ]
193
+
194
+ PROMPT_TEMPLATE = """You are an enterprise Ad Policy Compliance Agent.
195
+
196
+ You MUST choose exactly ONE action_type from this list (any other value is invalid):
197
+ - query_regulations
198
+ - analyze_image
199
+ - check_advertiser_history
200
+ - request_landing_page
201
+ - request_id_verification
202
+ - submit_audit
203
+ - approve
204
+ - reject
205
+
206
+ REQUIRED PHASE ORDER:
207
+ 1. query_regulations -> always first
208
+ 2. analyze_image / check_advertiser_history -> gather signals
209
+ 3. submit_audit -> always before final decision
210
+ 4. approve OR reject -> only after audit
211
+
212
+ HARD RULES:
213
+ - NEVER repeat an action listed in `actions_already_taken`.
214
+ - Respond with ONLY a valid JSON object. No markdown, no prose.
215
+
216
+ Required format:
217
+ {{"action_type": "<one_of_the_actions_above>", "reasoning": "<short reason>"}}
218
+
219
+ Scenario: {text}
220
+ actions_already_taken: {actions_already_taken}
221
+
222
+ Your next action?"""
223
+
224
+
225
+ def build_dataset():
226
+ rows = []
227
+ for s in BASE_SCENARIOS:
228
+ prompt = PROMPT_TEMPLATE.format(
229
+ text=s["text"],
230
+ actions_already_taken=json.dumps(s["actions_already_taken"]),
231
+ )
232
+ rows.append({
233
+ "prompt": prompt,
234
+ "task_id": s["task_id"],
235
+ "setup_actions": s["setup_actions"],
236
+ })
237
+ return Dataset.from_list(rows * 10) # 10 scenarios x 10 = 100 examples
238
+
239
+ # =========================
240
+ # REWARD FUNCTION (FIXED)
241
+ # =========================
242
+
243
+ def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
244
+ """Shaped reward for GRPO.
245
+
246
+ Pure env reward is too sparse (mostly -0.05) to give clear gradients.
247
+ We add explicit shaping:
248
+ - invalid JSON / invalid action_type -> -1.0 (strong negative signal)
249
+ - valid action env REJECTS (wrong phase / API failure) -> -0.5
250
+ - valid action env ACCEPTS (advances state) -> +0.5 + env_reward
251
+ - terminal correct decision -> env_reward already contains +1.0 bonus
252
+ """
253
+ client = EnvClient(ENV_URL)
254
+ rewards = []
255
+
256
+ for completion, t_id, setup in zip(completions, task_id, setup_actions):
257
+ parsed = extract_json(completion)
258
+ if not parsed:
259
+ rewards.append(-1.0)
260
+ continue
261
+
262
+ action_type = parsed.get("action_type")
263
+ if action_type not in ALLOWED_ACTIONS:
264
+ rewards.append(-1.0)
265
+ continue
266
+
267
+ action = {
268
+ "action_type": action_type,
269
+ "reasoning": parsed.get("reasoning", "format-compliant"),
270
+ }
271
+
272
+ try:
273
+ client.reset(t_id)
274
+ for s in setup:
275
+ safe_step(client, s)
276
+
277
+ result = safe_step(client, action)
278
+ env_reward = float(result.get("reward", -0.2))
279
+ status_msg = (result.get("status_message") or "").lower()
280
+
281
+ rejected = (
282
+ "api failure" in status_msg
283
+ or "invalid action" in status_msg
284
+ or "must call" in status_msg
285
+ )
286
+
287
+ if rejected:
288
+ shaped = -0.5
289
+ else:
290
+ shaped = 0.5 + env_reward
291
+
292
+ rewards.append(shaped)
293
+
294
+ except Exception:
295
+ rewards.append(-0.3)
296
+
297
+ return rewards
298
+
299
+ # =========================
300
+ # MODEL
301
+ # =========================
302
+
303
+ USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3
304
+
305
+ model, tokenizer = FastLanguageModel.from_pretrained(
306
+ model_name="unsloth/Llama-3.1-8B-Instruct",
307
+ load_in_4bit=USE_4BIT,
308
+ max_seq_length=2048,
309
+ dtype=None, # auto-detect bf16 on A100
310
+ )
311
+
312
+ model = FastLanguageModel.get_peft_model(
313
+ model,
314
+ r=32,
315
+ target_modules=[
316
+ "q_proj", "k_proj", "v_proj", "o_proj",
317
+ "gate_proj", "up_proj", "down_proj",
318
+ ],
319
+ lora_alpha=64,
320
+ lora_dropout=0,
321
+ bias="none",
322
+ use_gradient_checkpointing="unsloth",
323
+ random_state=3407,
324
+ )
325
+
326
+ # =========================
327
+ # TRAINER
328
+ # =========================
329
+
330
+ dataset = build_dataset()
331
+
332
+ trainer = GRPOTrainer(
333
+ model=model,
334
+ reward_funcs=[reward_environment],
335
+ args=GRPOConfig(
336
+ output_dir="outputs",
337
+ learning_rate=2e-5,
338
+ num_train_epochs=3,
339
+ per_device_train_batch_size=2,
340
+ gradient_accumulation_steps=4,
341
+ num_generations=4,
342
+ max_prompt_length=768,
343
+ max_completion_length=128,
344
+ logging_steps=5,
345
+ warmup_ratio=0.1,
346
+ bf16=True,
347
+ fp16=false,
348
+ report_to="none",
349
+ ),
350
+ train_dataset=dataset,
351
+ tokenizer=tokenizer,
352
+ )
353
+
354
+ # =========================
355
+ # RUN
356
+ # =========================
357
+
358
+ if __name__ == "__main__":
359
+ ensure_env_ready()
360
+
361
+ print("Starting GRPO training...")
362
+ trainer.train()
363
+
364
+ model.save_pretrained("outputs/lora_adapter")
365
+ tokenizer.save_pretrained("outputs/lora_adapter")
366
+ print("LoRA adapter saved to outputs/lora_adapter")
367
+
368
+ print("Merging adapter into base model (bf16)...")
369
+ merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
370
+ model_name="outputs/lora_adapter",
371
+ load_in_4bit=False,
372
+ max_seq_length=2048,
373
+ )
374
+ merged_model.save_pretrained_merged(
375
+ "outputs/merged",
376
+ merged_tokenizer,
377
+ save_method="merged_16bit",
378
+ )
379
+ print("Merged model saved to outputs/merged")
380
+
381
+ if HF_REPO:
382
+ print(f"Pushing merged model to {HF_REPO}...")
383
+ merged_model.push_to_hub_merged(
384
+ HF_REPO,
385
+ merged_tokenizer,
386
+ save_method="merged_16bit",
387
+ token=HF_TOKEN,
388
+ )
389
+ print(f"Model live at https://huggingface.co/{HF_REPO}")
390
+ else:
391
+ print("Set HF_REPO env var to auto-push to Hub (skipped).")
392
+
393
  print("Done.")