Kartik Goyal committed on
Commit
39124e2
·
1 Parent(s): d3424a0

added grpo colab file

Files changed (1)
  1. grpo_train.ipynb +540 -0
grpo_train.ipynb ADDED
@@ -0,0 +1,540 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "A100"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🛡️ MetaGuard — GRPO Training Notebook\n",
+ "\n",
+ "**Team:** Parth Singhal, Mehakveer Kaur, Kartik Goyal  \n",
+ "**HF Space:** https://huggingface.co/spaces/parth-1/MetaGuard  \n",
+ "**Hackathon:** OpenEnv — Meta × Scaler  \n",
+ "\n",
+ "This notebook fine-tunes **Llama 3.1 8B Instruct** (LoRA adapters) with GRPO on the MetaGuard Ad Policy Compliance environment.\n",
+ "\n",
+ "### What this trains:\n",
+ "- The agent learns to follow a structured SOP: `query_regulations → gather signals → submit_audit → decide`\n",
+ "- The reward is shaped by correctness, sequence compliance, and API failure recovery\n",
+ "- Environment hosted on HF Space (A100 runs both training + env)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 1 — Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install unsloth trl transformers datasets accelerate peft -q\n",
+ "!pip install openenv-core==0.2.1 --no-deps -q\n",
+ "!pip install fastapi uvicorn pydantic requests openai matplotlib -q\n",
+ "print('✅ Dependencies installed')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 2 — Clone Repo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "if not os.path.exists('meta-ad-policy-sandbox'):\n",
+ "    !git clone https://github.com/Parth380/meta-ad-policy-sandbox.git\n",
+ "\n",
+ "%cd meta-ad-policy-sandbox\n",
+ "print('✅ Repo ready')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 3 — Config (SET THESE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ['ENV_URL'] = 'https://parth-1-metaguard.hf.space'  # your HF Space URL\n",
+ "os.environ['HF_REPO'] = 'parth-1/metaguard-llama3.1-8b-grpo'  # model push destination\n",
+ "os.environ['HF_TOKEN'] = ''  # ← paste your HF write token here\n",
+ "\n",
+ "ENV_URL = os.environ['ENV_URL']\n",
+ "HF_TOKEN = os.environ['HF_TOKEN']\n",
+ "HF_REPO = os.environ['HF_REPO']\n",
+ "\n",
+ "print(f'ENV_URL : {ENV_URL}')\n",
+ "print(f'HF_REPO : {HF_REPO}')\n",
+ "print(f'HF_TOKEN : {\"set ✅\" if HF_TOKEN else \"MISSING ❌\"}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 4 — Wake Up HF Space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import requests, time\n",
+ "\n",
+ "print('Waking up HF Space...')\n",
+ "for i in range(20):\n",
+ "    try:\n",
+ "        r = requests.post(\n",
+ "            f\"{ENV_URL}/reset\",\n",
+ "            json={'task_id': 'task_1_healthcare'},\n",
+ "            timeout=10\n",
+ "        )\n",
+ "        if r.status_code == 200:\n",
+ "            print(f'✅ Environment ready (attempt {i+1})')\n",
+ "            break\n",
+ "    except Exception as e:\n",
+ "        print(f' attempt {i+1}: waiting... ({e})')\n",
+ "    time.sleep(3)\n",
+ "else:\n",
+ "    raise RuntimeError('❌ ENV not reachable after 20 attempts')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 5 — Imports + Helpers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import random\n",
+ "import torch\n",
+ "import matplotlib.pyplot as plt\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "from datasets import Dataset\n",
+ "from unsloth import FastLanguageModel, PatchFastRL\n",
+ "from trl import GRPOTrainer, GRPOConfig\n",
+ "\n",
+ "PatchFastRL('GRPO', FastLanguageModel)\n",
+ "\n",
+ "ALLOWED_ACTIONS = [\n",
+ "    'query_regulations', 'analyze_image', 'check_advertiser_history',\n",
+ "    'request_landing_page', 'request_id_verification',\n",
+ "    'submit_audit', 'approve', 'reject',\n",
+ "]\n",
+ "\n",
+ "class EnvClient:\n",
+ "    def __init__(self, url):\n",
+ "        self.url = url\n",
+ "    def reset(self, task_id):\n",
+ "        return requests.post(f'{self.url}/reset', json={'task_id': task_id}, timeout=8).json()\n",
+ "    def step(self, action):\n",
+ "        return requests.post(f'{self.url}/step', json={'action': action}, timeout=8).json()\n",
+ "\n",
+ "def safe_step(client, action):\n",
+ "    for _ in range(3):\n",
+ "        try:\n",
+ "            return client.step(action)\n",
+ "        except Exception:\n",
+ "            time.sleep(0.5)\n",
+ "    return {'reward': -0.3}\n",
+ "\n",
+ "def extract_json(text):\n",
+ "    try:\n",
+ "        if '```' in text:\n",
+ "            text = text.split('```')[1]\n",
+ "            if text.startswith('json'):\n",
+ "                text = text[4:]\n",
+ "        return json.loads(text.strip())\n",
+ "    except Exception:\n",
+ "        return None\n",
+ "\n",
+ "print('✅ Helpers loaded')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 6 — Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "BASE_SCENARIOS = [\n",
+ "    {'task_id': 'task_1_healthcare',\n",
+ "     'text': \"Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.\",\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_2_financial',\n",
+ "     'text': \"Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.\",\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_3_multimodal',\n",
+ "     'text': 'Multimodal ad: image may contain hidden violation. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_6_conflict',\n",
+ "     'text': 'High-trust advertiser but policy borderline. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_7_ambiguous',\n",
+ "     'text': 'Ambiguous wellness ad. Policy confidence low. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_8_adversarial',\n",
+ "     'text': 'Natural supplement ad. Image may hide violation. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_9_dependency_trap',\n",
+ "     'text': 'Certified wellness product. Regulations and CRM look clean. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_10_failure',\n",
+ "     'text': 'Healthcare ad. First API call may fail. No actions taken yet.',\n",
+ "     'actions_already_taken': [], 'setup_actions': []},\n",
+ "    {'task_id': 'task_1_healthcare',\n",
+ "     'text': 'Healthcare ad. Policy already queried.',\n",
+ "     'actions_already_taken': ['query_regulations'],\n",
+ "     'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
+ "    {'task_id': 'task_2_financial',\n",
+ "     'text': 'Financial ad. Policy and history checked. Submit audit next.',\n",
+ "     'actions_already_taken': ['query_regulations', 'check_advertiser_history'],\n",
+ "     'setup_actions': [\n",
+ "         {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
+ "         {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
+ "     ]},\n",
+ "    {'task_id': 'task_2_financial',\n",
+ "     'text': 'Financial ad. Policy, history, audit all complete. Make final decision.',\n",
+ "     'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'submit_audit'],\n",
+ "     'setup_actions': [\n",
+ "         {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
+ "         {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
+ "         {'action_type': 'submit_audit', 'reasoning': 'audit log'},\n",
+ "     ]},\n",
+ "]\n",
+ "\n",
+ "PROMPT_TEMPLATE = \"\"\"You are an enterprise Ad Policy Compliance Agent.\n",
+ "\n",
+ "You MUST choose exactly ONE action_type from this list (any other value is invalid):\n",
+ "- query_regulations\n",
+ "- analyze_image\n",
+ "- check_advertiser_history\n",
+ "- request_landing_page\n",
+ "- request_id_verification\n",
+ "- submit_audit\n",
+ "- approve\n",
+ "- reject\n",
+ "\n",
+ "REQUIRED PHASE ORDER:\n",
+ "1. query_regulations -> always first\n",
+ "2. analyze_image / check_advertiser_history -> gather signals\n",
+ "3. submit_audit -> always before final decision\n",
+ "4. approve OR reject -> only after audit\n",
+ "\n",
+ "HARD RULES:\n",
+ "- NEVER repeat an action listed in `actions_already_taken`.\n",
+ "- Respond with ONLY a valid JSON object. No markdown, no prose.\n",
+ "\n",
+ "Required format:\n",
+ "{{\\\"action_type\\\": \\\"<one_of_the_actions_above>\\\", \\\"reasoning\\\": \\\"<short reason>\\\"}}\n",
+ "\n",
+ "Scenario: {text}\n",
+ "actions_already_taken: {actions_already_taken}\n",
+ "\n",
+ "Your next action?\"\"\"\n",
+ "\n",
+ "def build_dataset():\n",
+ "    rows = []\n",
+ "    for s in BASE_SCENARIOS:\n",
+ "        prompt = PROMPT_TEMPLATE.format(\n",
+ "            text=s['text'],\n",
+ "            actions_already_taken=json.dumps(s['actions_already_taken']),\n",
+ "        )\n",
+ "        rows.append({'prompt': prompt, 'task_id': s['task_id'], 'setup_actions': s['setup_actions']})\n",
+ "    return Dataset.from_list(rows * 10)\n",
+ "\n",
+ "dataset = build_dataset()\n",
+ "print(f'✅ Dataset: {len(dataset)} examples')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 7 — Reward Function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Track rewards for plotting\n",
+ "reward_log = []\n",
+ "step_log = []\n",
+ "global_step_counter = [0]\n",
+ "\n",
+ "def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):\n",
+ "    client = EnvClient(ENV_URL)\n",
+ "    rewards = []\n",
+ "\n",
+ "    for completion, t_id, setup in zip(completions, task_id, setup_actions):\n",
+ "        parsed = extract_json(completion)\n",
+ "        if not isinstance(parsed, dict):\n",
+ "            rewards.append(-1.0)\n",
+ "            continue\n",
+ "\n",
+ "        action_type = parsed.get('action_type')\n",
+ "        if action_type not in ALLOWED_ACTIONS:\n",
+ "            rewards.append(-1.0)\n",
+ "            continue\n",
+ "\n",
+ "        action = {\n",
+ "            'action_type': action_type,\n",
+ "            'reasoning': parsed.get('reasoning', 'format-compliant'),\n",
+ "        }\n",
+ "\n",
+ "        try:\n",
+ "            client.reset(t_id)\n",
+ "            for s in setup:\n",
+ "                safe_step(client, s)\n",
+ "\n",
+ "            result = safe_step(client, action)\n",
+ "            env_reward = float(result.get('reward', -0.2))\n",
+ "            status_msg = (result.get('status_message') or '').lower()\n",
+ "\n",
+ "            rejected = (\n",
+ "                'api failure' in status_msg\n",
+ "                or 'invalid action' in status_msg\n",
+ "                or 'must call' in status_msg\n",
+ "            )\n",
+ "            shaped = -0.5 if rejected else 0.5 + env_reward\n",
+ "            rewards.append(shaped)\n",
+ "\n",
+ "        except Exception:\n",
+ "            rewards.append(-0.3)\n",
+ "\n",
+ "    # Log for plot\n",
+ "    avg = sum(rewards) / len(rewards) if rewards else 0.0\n",
+ "    global_step_counter[0] += 1\n",
+ "    reward_log.append(avg)\n",
+ "    step_log.append(global_step_counter[0])\n",
+ "\n",
+ "    return rewards\n",
+ "\n",
+ "print('✅ Reward function ready')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 8 — Load Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3\n",
+ "print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"CPU\"}')\n",
+ "print(f'4-bit quant: {USE_4BIT}')\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ "    model_name='unsloth/Llama-3.1-8B-Instruct',\n",
+ "    load_in_4bit=USE_4BIT,\n",
+ "    max_seq_length=2048,\n",
+ "    dtype=None,\n",
+ ")\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ "    model,\n",
+ "    r=32,\n",
+ "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
+ "    lora_alpha=64,\n",
+ "    lora_dropout=0,\n",
+ "    bias='none',\n",
+ "    use_gradient_checkpointing='unsloth',\n",
+ "    random_state=3407,\n",
+ ")\n",
+ "print('✅ Model loaded')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 9 — Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = GRPOTrainer(\n",
+ "    model=model,\n",
+ "    reward_funcs=[reward_environment],\n",
+ "    args=GRPOConfig(\n",
+ "        output_dir='outputs',\n",
+ "        learning_rate=2e-5,\n",
+ "        num_train_epochs=3,\n",
+ "        per_device_train_batch_size=2,\n",
+ "        gradient_accumulation_steps=4,\n",
+ "        num_generations=4,\n",
+ "        max_prompt_length=768,\n",
+ "        max_completion_length=128,\n",
+ "        logging_steps=5,\n",
+ "        warmup_ratio=0.1,\n",
+ "        bf16=True,\n",
+ "        report_to='none',\n",
+ "    ),\n",
+ "    train_dataset=dataset,\n",
+ "    processing_class=tokenizer,  # GRPOTrainer takes the tokenizer as processing_class\n",
+ ")\n",
+ "\n",
+ "print('🚀 Starting GRPO training...')\n",
+ "trainer.train()\n",
+ "print('✅ Training complete')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 10 — Plot Reward Curve"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# Smooth with moving average\n",
+ "def moving_avg(data, window=5):\n",
+ "    if len(data) < window:\n",
+ "        return data\n",
+ "    return np.convolve(data, np.ones(window)/window, mode='valid')\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
+ "ax.plot(step_log, reward_log, alpha=0.3, color='steelblue', label='Raw reward')\n",
+ "smoothed = moving_avg(reward_log)\n",
+ "ax.plot(step_log[len(step_log) - len(smoothed):], smoothed, color='steelblue', linewidth=2, label='Smoothed (MA-5)')\n",
+ "ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
+ "ax.set_xlabel('Training Step')\n",
+ "ax.set_ylabel('Avg Reward per Batch')\n",
+ "ax.set_title('MetaGuard GRPO — Reward Curve')\n",
+ "ax.legend()\n",
+ "ax.grid(alpha=0.3)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('outputs/reward_plot.png', dpi=150)\n",
+ "plt.show()\n",
+ "print('✅ Plot saved to outputs/reward_plot.png')\n",
+ "\n",
+ "# Print before/after summary\n",
+ "n = len(reward_log)\n",
+ "first_10 = reward_log[:min(10, n)]\n",
+ "last_10 = reward_log[max(0, n-10):]\n",
+ "print(f'\\n--- Results ---')\n",
+ "print(f'Avg reward (first 10 steps): {sum(first_10)/len(first_10):.3f}')\n",
+ "print(f'Avg reward (last 10 steps) : {sum(last_10)/len(last_10):.3f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 11 — Save + Push to HF Hub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save_pretrained('outputs/lora_adapter')\n",
+ "tokenizer.save_pretrained('outputs/lora_adapter')\n",
+ "print('✅ LoRA adapter saved')\n",
+ "\n",
+ "print('Merging adapter into base model (bf16)...')\n",
+ "merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(\n",
+ "    model_name='outputs/lora_adapter',\n",
+ "    load_in_4bit=False,\n",
+ "    max_seq_length=2048,\n",
+ ")\n",
+ "merged_model.save_pretrained_merged(\n",
+ "    'outputs/merged',\n",
+ "    merged_tokenizer,\n",
+ "    save_method='merged_16bit',\n",
+ ")\n",
+ "print('✅ Merged model saved')\n",
+ "\n",
+ "if HF_REPO and HF_TOKEN:\n",
+ "    print(f'Pushing to {HF_REPO}...')\n",
+ "    merged_model.push_to_hub_merged(\n",
+ "        HF_REPO,\n",
+ "        merged_tokenizer,\n",
+ "        save_method='merged_16bit',\n",
+ "        token=HF_TOKEN,\n",
+ "    )\n",
+ "    print(f'✅ Model live at https://huggingface.co/{HF_REPO}')\n",
+ "else:\n",
+ "    print('⚠️ Set HF_REPO and HF_TOKEN to push to Hub')\n",
+ "\n",
+ "print('Done.')"
+ ]
+ }
+ ]
+ }
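
Note on the reward function (Cell 7): the shaping rule can be checked offline, without the live HF Space. The sketch below is not part of this commit. It restates the parse, validate, and shape steps as a pure function, and `shape_reward`, `REJECTION_MARKERS`, and `FENCE` are hypothetical names introduced here for illustration. It assumes the environment returns a dict with optional `reward` and `status_message` keys, as the notebook's `result.get(...)` calls suggest.

```python
import json

# Action names copied from the notebook's ALLOWED_ACTIONS list.
ALLOWED_ACTIONS = {
    'query_regulations', 'analyze_image', 'check_advertiser_history',
    'request_landing_page', 'request_id_verification',
    'submit_audit', 'approve', 'reject',
}
# Status substrings the notebook treats as a rejected step.
REJECTION_MARKERS = ('api failure', 'invalid action', 'must call')
# Markdown code-fence delimiter, built this way to avoid a literal fence in the source.
FENCE = '`' * 3


def extract_json(text):
    """Parse a JSON value from a completion, tolerating one Markdown code fence."""
    try:
        if FENCE in text:
            text = text.split(FENCE)[1]
            if text.startswith('json'):
                text = text[4:]
        return json.loads(text.strip())
    except (json.JSONDecodeError, IndexError, AttributeError, TypeError):
        return None


def shape_reward(completion, env_result):
    """Hypothetical pure-function version of the per-completion shaping rule."""
    parsed = extract_json(completion)
    # Malformed output or an unknown action gets the format penalty.
    if not isinstance(parsed, dict) or parsed.get('action_type') not in ALLOWED_ACTIONS:
        return -1.0
    status = (env_result.get('status_message') or '').lower()
    # A step the environment rejected gets a flat penalty.
    if any(marker in status for marker in REJECTION_MARKERS):
        return -0.5
    # Otherwise: format bonus plus the environment's own reward.
    return 0.5 + float(env_result.get('reward', -0.2))
```

Keeping the rule separate from the HTTP calls makes it possible to unit-test the three reward tiers (format penalty, rejection penalty, bonus plus environment reward) before spending GPU time on a training run.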