Ev3Dev committed on
Commit ad39f2a · verified · 1 Parent(s): db03c40

Upload folder using huggingface_hub
H100_JUPYTER_SETUP.md CHANGED
@@ -4,7 +4,7 @@ This guide walks you through setting up the OpenEnv Bio Experiment environment o
4
 
5
  ## Prerequisites
6
 
7
- - **Python** 3.10 (3.10, 3.11, or 3.12 recommended)
8
  - **uv** – fast Python package manager ([install instructions](#installing-uv))
9
  - **NVIDIA driver** ≥ 535.104.05 (usually pre-installed on H100 instances)
10
  - **CUDA** – H100 uses CUDA 12.x; PyTorch wheels bundle the runtime, so a separate CUDA Toolkit is not required
@@ -38,14 +38,16 @@ cd OpenENV-Hackathon
38
 
39
  ### 2. Use uv's auto PyTorch backend
40
 
41
- uv can detect your GPU and pick the right PyTorch build. For H100 (CUDA 12.x):
42
 
43
  ```bash
44
- # Install everything: core + training (TRL, transformers, torch, unsloth) + Jupyter
45
  UV_TORCH_BACKEND=cu128 uv sync --extra train
46
 
47
- # Add Jupyter kernel support
48
- uv add ipykernel jupyter --extra train
49
  ```
50
 
51
  If `UV_TORCH_BACKEND=cu128` fails (e.g., cu128 wheels not available yet), try:
@@ -57,12 +59,20 @@ UV_TORCH_BACKEND=cu126 uv sync --extra train
57
  ### 3. Register the environment as a Jupyter kernel
58
 
59
  ```bash
60
- uv run python -m ipykernel install --user --name openenv-bio --display-name "OpenEnv Bio (H100)"
61
  ```
62
 
63
  ### 4. Verify CUDA
64
 
65
- In a new Jupyter notebook, select the **"OpenEnv Bio (H100)"** kernel and run:
66
 
67
  ```python
68
  import torch
@@ -108,7 +118,6 @@ Then run:
108
 
109
  ```bash
110
  uv sync --extra train
111
- uv add ipykernel jupyter --extra train
112
  ```
113
 
114
  For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
@@ -128,20 +137,20 @@ For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
128
  On H100, use the quantized Unsloth entrypoints:
129
 
130
  ```bash
131
- uv run python training_unsloth.py --dry-run
132
- uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --output-dir training/grpo-unsloth-output
133
  uv run python run_agent_unsloth.py
134
  ```
135
 
136
- The checked-in `inference.ipynb` notebook now uses `training_unsloth.py` helpers with 4-bit loading and fast inference enabled by default.
137
 
138
  ## Running Training in a Jupyter Notebook
139
 
140
  Example cell:
141
 
142
  ```python
143
- # In a notebook with the OpenEnv Bio (H100) kernel
144
- !uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B --dry-run
145
  ```
146
 
147
  Or run interactively from Python:
@@ -150,8 +159,8 @@ Or run interactively from Python:
150
  import subprocess
151
  subprocess.run([
152
  "uv", "run", "python", "training_unsloth.py",
153
- "--model-id", "Qwen/Qwen3.5-4B",
154
- "--output-dir", "training/grpo-unsloth-output",
155
  ], check=True)
156
  ```
157
 
@@ -159,16 +168,32 @@ subprocess.run([
159
 
160
  | Component | Version / Notes |
161
  |----------------|------------------------------------------------------|
162
- | Python | 3.10 |
163
  | uv | ≥ 0.5.3 (for PyTorch index support) |
164
  | torch | ≥ 2.10.0 (cu128 or cu126 for H100) |
165
- | transformers | ≥ 5.3.0 |
166
  | trl | ≥ 0.29.0 |
167
  | accelerate | ≥ 1.13.0 |
168
  | Jupyter | Optional, for notebook workflows |
169
 
170
  ## Troubleshooting
171
 
172
  ### `torch.cuda.is_available()` is False
173
 
174
  - Confirm the Jupyter kernel is the one where you ran `uv sync` (the one with `ipykernel`).
@@ -193,6 +218,51 @@ export HF_HUB_DISABLE_SYMLINKS_WARNING=1
193
  - Use a smaller model (e.g., `Qwen/Qwen3.5-0.8B`) for experiments.
194
 - Keep `--disable-4bit` off unless you explicitly need higher-precision (non-4-bit) weights.
195
 
196
  ## See Also
197
 
198
  - Main [README.md](README.md) for project overview, APIs, and usage
 
4
 
5
  ## Prerequisites
6
 
7
+ - **Python** 3.10, 3.11, or **3.12** (3.12 recommended for H100; 3.13 is unsupported because numba, vllm, and other dependencies require <3.13)
8
  - **uv** – fast Python package manager ([install instructions](#installing-uv))
9
  - **NVIDIA driver** ≥ 535.104.05 (usually pre-installed on H100 instances)
10
  - **CUDA** – H100 uses CUDA 12.x; PyTorch wheels bundle the runtime, so a separate CUDA Toolkit is not required
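
To confirm the driver prerequisite is met before installing anything, a quick probe (a sketch; it assumes `nvidia-smi` ships with the driver, which it does on standard installs):

```python
import shutil
import subprocess

# nvidia-smi is installed alongside the NVIDIA driver; if it's missing,
# the driver prerequisite above is probably not satisfied.
if shutil.which("nvidia-smi"):
    subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv"],
        check=False,
    )
else:
    print("nvidia-smi not found; install or verify the NVIDIA driver first")
```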
 
38
 
39
  ### 2. Use uv's auto PyTorch backend
40
 
41
+ The project uses Python 3.12 (see `.python-version`). uv will create a 3.12 venv. For H100 (CUDA 12.x):
42
 
43
  ```bash
44
+ # Install everything: core + training (TRL, transformers, torch) + Jupyter
45
  UV_TORCH_BACKEND=cu128 uv sync --extra train
46
 
47
+ # Add Unsloth for training_unsloth.py (skips trl downgrade; Unsloth works with TRL 0.29)
48
+ uv pip install unsloth unsloth_zoo --no-deps
49
+
50
+ # (ipykernel is included in --extra train)
51
  ```
52
 
53
  If `UV_TORCH_BACKEND=cu128` fails (e.g., cu128 wheels not available yet), try:
 
59
  ### 3. Register the environment as a Jupyter kernel
60
 
61
  ```bash
62
+ uv run python -m ipykernel install --user --name openenv-bio-312 --display-name "OpenEnv Bio (Python 3.12)"
63
+ ```
64
+
65
+ Or run the helper script (from project root):
66
+
67
+ ```bash
68
+ bash scripts/register_kernel_312.sh
69
  ```
70
 
71
+ Then select **"OpenEnv Bio (Python 3.12)"** in the notebook kernel picker.
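
To verify the notebook actually picked up that kernel, a minimal check from a cell (the `.venv` location is uv's default and is an assumption about your setup):

```python
import sys

# The kernel's interpreter should live inside the project's uv venv
# (by default <project>/.venv) and report Python 3.12.
print("interpreter:", sys.executable)
print("version:", ".".join(map(str, sys.version_info[:2])))
```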
72
+
73
  ### 4. Verify CUDA
74
 
75
+ In a new Jupyter notebook, select the **"OpenEnv Bio (Python 3.12)"** kernel and run:
76
 
77
  ```python
78
  import torch
 
118
 
119
  ```bash
120
  uv sync --extra train
 
121
  ```
122
 
123
  For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
 
137
  On H100, use the quantized Unsloth entrypoints:
138
 
139
  ```bash
140
+ uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dry-run
141
+ uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b
142
  uv run python run_agent_unsloth.py
143
  ```
144
 
145
+ The checked-in `inference.ipynb` notebook uses `training_unsloth.py` helpers with 4-bit loading. vLLM fast inference is disabled to avoid dependency conflicts.
146
 
147
  ## Running Training in a Jupyter Notebook
148
 
149
  Example cell:
150
 
151
  ```python
152
+ # In a notebook with the OpenEnv Bio (Python 3.12) kernel
153
+ !uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dry-run
154
  ```
155
 
156
  Or run interactively from Python:
 
159
  import subprocess
160
  subprocess.run([
161
  "uv", "run", "python", "training_unsloth.py",
162
+ "--model-id", "Qwen/Qwen3-4B-Base",
163
+ "--output-dir", "training/grpo-unsloth-qwen3-4b",
164
  ], check=True)
165
  ```
166
 
 
168
 
169
  | Component | Version / Notes |
170
  |----------------|------------------------------------------------------|
171
+ | Python | 3.10–3.12 (3.12 recommended; 3.13 not supported) |
172
  | uv | ≥ 0.5.3 (for PyTorch index support) |
173
  | torch | ≥ 2.10.0 (cu128 or cu126 for H100) |
174
+ | transformers | ≥ 4.57 (with unsloth ≥ 2025.10.14) |
175
  | trl | ≥ 0.29.0 |
176
  | accelerate | ≥ 1.13.0 |
177
  | Jupyter | Optional, for notebook workflows |
178
 
179
  ## Troubleshooting
180
 
181
+ ### `RuntimeError: Cannot install on Python version 3.13.x` or numba / setup.py errors
182
+
183
+ Python 3.13 is not supported (numba, vllm, and other deps require <3.13). Use Python 3.12:
184
+
185
+ ```bash
186
+ # With uv: ensure Python 3.12 is available, then sync
187
+ uv python install 3.12
188
+ uv sync --extra train
189
+
190
+ # Or create venv explicitly with 3.12
191
+ uv venv --python 3.12
192
+ UV_TORCH_BACKEND=cu128 uv sync --extra train
193
+ ```
194
+
195
+ The project's `.python-version` file pins 3.12; uv will use it when creating the venv.
196
+
197
  ### `torch.cuda.is_available()` is False
198
 
199
  - Confirm the Jupyter kernel is the one where you ran `uv sync` (the one with `ipykernel`).
 
218
  - Use a smaller model (e.g., `Qwen/Qwen3.5-0.8B`) for experiments.
219
 - Keep `--disable-4bit` off unless you explicitly need higher-precision (non-4-bit) weights.
220
 
221
+ ### `ModuleNotFoundError: No module named 'vllm.lora.models'`
222
+
223
+ Unsloth's `unsloth_zoo` imports vLLM at load time and expects `vllm.lora.models`, which some vLLM versions don't have. Fix by installing a compatible vLLM:
224
+
225
+ ```bash
226
+ pip install "vllm==0.8.2"
227
+ # or
228
+ pip install "vllm==0.7.3"
229
+ ```
230
+
231
+ **Note:** vLLM 0.8.2 pins `torch==2.6.0`, which conflicts with this project's `torch>=2.10.0`. If you hit that conflict:
232
+
233
+ 1. Use a **separate environment** with torch 2.6–2.8 + vllm 0.8.2 + unsloth.
234
+ 2. Or use the non-Unsloth path (`training_script.py` / `train.ipynb`) which doesn't depend on vLLM.
235
+
236
+ ### `KeyError: 'qwen3_5'` / Qwen3.5 not supported
237
+
238
+ Qwen3.5 requires transformers 5.x. With transformers 4.57, use **Qwen2.5** instead:
239
+ - `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`
240
+ - `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`
241
+ - `Qwen/Qwen2.5-3B-Instruct`
242
+
243
+ ### `NameError: name 'PreTrainedConfig' is not defined` / `check_model_inputs` ImportError
244
+
245
+ Use unsloth ≥ 2025.10.14 (for the PreTrainedConfig fix) with transformers ≥ 4.57 (which provides check_model_inputs). Run `uv sync --extra train` to get compatible versions.
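
To see which versions are actually installed in the active environment (a sketch using only the standard library; the package names are as published on PyPI):

```python
from importlib.metadata import PackageNotFoundError, version

# Print the installed version of each package the fix above depends on.
for pkg in ("unsloth", "unsloth_zoo", "transformers", "trl", "torch"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```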
246
+
247
+ ### `ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'`
248
+
249
+ unsloth_zoo expects TRL <0.20. The project pins `trl>=0.19.0,<0.20`. If you see this error, ensure you've run `uv sync --extra train` so the locked trl version is used. Alternatively, try:
250
+
251
+ ```bash
252
+ pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo
253
+ ```
254
+
255
+ (A newer unsloth_zoo may fix this and allow TRL 0.20+.)
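
A quick way to probe whether the installed TRL still exposes the symbol unsloth_zoo needs (a sketch; run it in the same environment):

```python
# If this import fails, the installed TRL is too new for the installed
# unsloth_zoo: pin trl<0.20 or force-reinstall unsloth/unsloth_zoo.
try:
    from trl.trainer.utils import ConstantLengthDataset  # noqa: F401
    print("ConstantLengthDataset available: TRL/unsloth_zoo combination OK")
except ImportError:
    print("ConstantLengthDataset missing: TRL too new for this unsloth_zoo")
```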
256
+
257
+ ### Unsloth import order warning
258
+
259
+ If you see "Unsloth should be imported before trl, transformers, peft", ensure `training_unsloth` is imported before `training_script` in your notebook:
260
+
261
+ ```python
262
+ from training_unsloth import make_training_args, run_training # first
263
+ import training_script as base
264
+ ```
265
+
266
  ## See Also
267
 
268
  - Main [README.md](README.md) for project overview, APIs, and usage
README.md CHANGED
@@ -384,6 +384,23 @@ uv run python training_unsloth.py --dry-run
384
  uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B
385
  ```
386
 
387
  Key arguments:
388
 
389
  | Argument | Default | Description |
@@ -405,7 +422,7 @@ Key arguments:
405
 
406
  By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
407
 
408
- `training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`, `--disable-fast-inference`, and LoRA settings (`--lora-r`, `--lora-alpha`, `--lora-dropout`).
409
 
410
  After training, the script saves plots to the output directory:
411
 
 
384
  uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B
385
  ```
386
 
387
+ **Laptop / mid-range GPU (e.g. 12GB VRAM):** Use reduced batch size and sequence length to avoid OOM:
388
+
389
+ ```bash
390
+ uv sync --extra train
391
+ uv pip install unsloth unsloth_zoo --no-deps # if using training_unsloth.py
392
+ uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dataset-episodes 12 --rollout-steps 6 --per-device-train-batch-size 1 --num-generations 2 --gradient-accumulation-steps 4 --max-seq-length 1024 --trust-remote-code
393
+ ```
394
+
395
+ If you still hit OOM, try `--max-seq-length 768` or `--num-generations 1`.
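
Before retrying, it helps to see how much VRAM is actually free; a sketch assuming torch is installed in the environment:

```python
import torch

# torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the device,
# which tells you how much headroom --max-seq-length and batch size have.
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()
    print(f"VRAM free: {free / 1e9:.1f} GB of {total / 1e9:.1f} GB")
else:
    print("CUDA not available; check your PyTorch install")
```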
396
+
397
+ **PyTorch CUDA:** Use the PyTorch index that matches your GPU. For older cards (RTX 20/30/40 series): `uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`. For **RTX 50 series (Blackwell, sm_120)** you need a CUDA 12.8 build:
398
+
399
+ ```bash
400
+ uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
401
+ uv pip install triton-windows # required by Unsloth on Windows
402
+ ```
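
To check which build your card needs, query its compute capability (a sketch; sm_120 is Blackwell/RTX 50, sm_90 is H100):

```python
import torch

# get_device_capability returns (major, minor); sm_120 cards need the
# cu128 nightly build, while older architectures work with cu121/cu126.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"compute capability: sm_{major}{minor}")
else:
    print("CUDA not available")
```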
403
+
404
  Key arguments:
405
 
406
  | Argument | Default | Description |
 
422
 
423
  By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
424
 
425
+ `training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`, and LoRA settings (`--lora-r`, `--lora-alpha`, `--lora-dropout`). vLLM fast inference is disabled to avoid dependency conflicts.
426
 
427
  After training, the script saves plots to the output directory:
428
 
_dashboard_state.json CHANGED
@@ -1,29 +1,29 @@
1
  {
2
- "timestamp": 1772965877.7057643,
3
  "step": 20,
4
  "episode_done": false,
5
- "cumulative_reward": -1.0267290925637838,
6
- "gen_time_s": 28.83,
7
- "model_response_raw": "{\n\"action_type\": \"marker_selection\",\n\"method\": \"\",\n\"parameters\": {\"\"diff_genes\"\": [],\"gene_lists_file\":\"DE_results.h5ad\"},\n\"Justification\": \"The 'violation' error indicates we cannot proceed directly from previous failed attempts due to incomplete filtering/normalization logic implied by the blockage condition ('without prior analysis'). However, looking strictly at the provided completed vs failing log chain where `design` was followed by `[OK] design`, but then immediate failures occurred (`synthesis_blocked`) which were resolved upon finding/deepening insights? Wait, re-evaluating strict fl",
8
  "model_thinking": "",
9
  "thinking_enabled": false,
10
  "task": {
11
- "problem_statement": "Validate candidate biomarker TREM2 for Alzheimer's disease in brain tissue using single-cell RNA sequencing.",
12
  "organism": "human",
13
- "tissue": "brain",
14
  "modality": "scRNA-seq",
15
  "conditions": [
16
  "healthy",
17
- "Alzheimer"
18
  ],
19
- "budget_limit": 115264.15272286453,
20
- "time_limit_days": 111.8768783273548
21
  },
22
  "resources": {
23
  "budget_used": 28900.0,
24
- "budget_remaining": 86364.15,
25
  "time_used_days": 18.0,
26
- "time_remaining_days": 93.9,
27
  "samples_consumed": 0,
28
  "compute_hours_used": 12.0
29
  },
@@ -32,9 +32,9 @@
32
  "step_index": 1,
33
  "action_type": "collect_sample",
34
  "method": null,
35
- "output_summary": "Collected 6 samples (quality=0.57)",
36
  "success": true,
37
- "quality_score": 0.575,
38
  "resource_cost": 5000.0,
39
  "time_cost_days": 7.0
40
  },
@@ -62,9 +62,9 @@
62
  "step_index": 4,
63
  "action_type": "prepare_library",
64
  "method": null,
65
- "output_summary": "Library prepared (complexity=0.77)",
66
  "success": true,
67
- "quality_score": 0.77,
68
  "resource_cost": 8000.0,
69
  "time_cost_days": 3.0
70
  },
@@ -72,42 +72,42 @@
72
  "step_index": 5,
73
  "action_type": "sequence_cells",
74
  "method": null,
75
- "output_summary": "Sequenced 12178 cells, 9490 genes detected, median UMI=2966",
76
  "success": true,
77
- "quality_score": 0.457,
78
  "resource_cost": 15000.0,
79
  "time_cost_days": 5.0
80
  },
81
  {
82
  "step_index": 6,
83
- "action_type": "filter_data",
84
- "method": null,
85
- "output_summary": "Action blocked: Cannot filter data before QC",
86
- "success": false,
87
- "quality_score": 1.0,
88
- "resource_cost": 50.0,
89
- "time_cost_days": 0.2
90
- },
91
- {
92
- "step_index": 7,
93
  "action_type": "run_qc",
94
  "method": null,
95
  "output_summary": "QC metrics computed",
96
  "success": true,
97
- "quality_score": 0.842,
98
  "resource_cost": 100.0,
99
  "time_cost_days": 0.5
100
  },
101
  {
102
- "step_index": 8,
103
  "action_type": "filter_data",
104
  "method": null,
105
- "output_summary": "Filtered 12178 \u2192 9967 cells (82% retained)",
106
  "success": true,
107
- "quality_score": 0.818,
108
  "resource_cost": 50.0,
109
  "time_cost_days": 0.2
110
  },
 
111
  {
112
  "step_index": 9,
113
  "action_type": "normalize_data",
@@ -142,14 +142,24 @@
142
  "step_index": 12,
143
  "action_type": "cluster_cells",
144
  "method": null,
145
- "output_summary": "Found 3 clusters",
146
  "success": true,
147
- "quality_score": 0.765,
148
  "resource_cost": 100.0,
149
  "time_cost_days": 0.5
150
  },
151
  {
152
  "step_index": 13,
153
  "action_type": "design_followup_experiment",
154
  "method": null,
155
  "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
@@ -159,7 +169,7 @@
159
  "time_cost_days": 0.5
160
  },
161
  {
162
- "step_index": 14,
163
  "action_type": "synthesize_conclusion",
164
  "method": null,
165
  "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
@@ -168,16 +178,6 @@
168
  "resource_cost": 0.0,
169
  "time_cost_days": 0.5
170
  },
171
- {
172
- "step_index": 15,
173
- "action_type": "marker_selection",
174
- "method": null,
175
- "output_summary": "Action blocked: Cannot select markers without DE results",
176
- "success": false,
177
- "quality_score": 1.0,
178
- "resource_cost": 100.0,
179
- "time_cost_days": 0.5
180
- },
181
  {
182
  "step_index": 16,
183
  "action_type": "synthesize_conclusion",
@@ -222,8 +222,8 @@
222
  "Cannot synthesise conclusion before inferring pathways or mechanisms"
223
  ],
224
  "uncertainty_summary": {
225
- "avg_uncertainty": 0.194,
226
- "avg_quality": 0.809
227
  },
228
  "reward_breakdown": {
229
  "validity": -1.0,
@@ -249,72 +249,92 @@
249
  "latent": {
250
  "cell_populations": [
251
  {
252
- "name": "excitatory_neuron",
253
- "proportion": 0.425,
254
  "marker_genes": [
255
- "SLC17A7",
256
- "CAMK2A",
257
- "NRGN"
258
  ],
259
- "state": "stressed"
260
  },
261
  {
262
- "name": "inhibitory_neuron",
263
- "proportion": 0.346,
264
  "marker_genes": [
265
- "GAD1",
266
- "GAD2",
267
- "SLC32A1"
268
  ],
269
  "state": "normal"
270
  },
271
  {
272
- "name": "OPC",
273
- "proportion": 0.093,
274
  "marker_genes": [
275
- "PDGFRA",
276
- "CSPG4",
277
- "OLIG2"
278
  ],
279
- "state": "progenitor"
280
  },
281
  {
282
  "name": "endothelial",
283
- "proportion": 0.137,
284
  "marker_genes": [
285
- "CLDN5",
286
- "FLT1",
287
- "PECAM1"
288
  ],
289
  "state": "quiescent"
290
  }
291
  ],
292
  "true_markers": [
293
- "TREM2",
294
- "APOE",
295
- "GFAP"
 
296
  ],
297
  "causal_mechanisms": [
298
- "TREM2-mediated microglial activation in amyloid clearance",
299
- "complement-driven synaptic pruning",
300
- "reactive astrogliosis amplifying neuroinflammation"
301
  ],
302
  "true_pathways": {
303
- "complement_cascade": 0.827,
304
- "neuroinflammation": 0.666,
305
- "amyloid_processing": 0.733,
306
- "synaptic_signalling": 0.438,
307
- "lipid_metabolism": 0.616
308
  },
309
- "true_de_genes_count": 10,
310
- "true_regulatory_network_size": 0,
311
  "confounders": {},
312
- "n_true_cells": 20321,
313
  "technical": {
314
- "ambient_rna_fraction": 0.050723618495539344,
315
- "doublet_rate": 0.0546771548836933,
316
- "dropout_rate": 0.05122168063297322,
317
- "sample_quality": 0.937985596833521,
318
  "library_complexity": 0.8,
319
  "capture_efficiency": 0.6
320
  },
@@ -339,9 +359,9 @@
339
  "followup_designed": true,
340
  "subagent_review_requested": false,
341
  "conclusion_reached": false,
342
- "n_cells_sequenced": 12178,
343
- "n_cells_after_filter": 9967,
344
- "n_clusters_found": "3",
345
  "n_de_genes_found": null,
346
  "n_markers_found": null
347
  },
 
1
  {
2
+ "timestamp": 1772986999.4539518,
3
  "step": 20,
4
  "episode_done": false,
5
+ "cumulative_reward": -1.2923024829860532,
6
+ "gen_time_s": 17.14,
7
+ "model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"candidate_target_genes\": [],\n\"suggested_wet_lab_targets\": []\n},\n\"Justification\": \"The previous attempt failed due to incomplete genomic profiling required for robust marker identification via CRISPR/Cas9 editing followed by functional readout. Without identifying specific high-confidence target genes within the identified clusters through comprehensive transcriptomic characterization first, designing a focused knockdown screen lacks the necessary molecular basis.\",\n\"Confidence\": 0.85\n}",
8
  "model_thinking": "",
9
  "thinking_enabled": false,
10
  "task": {
11
+ "problem_statement": "Validate candidate biomarker ERBB2 for breast cancer in breast tissue using single-cell RNA sequencing.",
12
  "organism": "human",
13
+ "tissue": "breast",
14
  "modality": "scRNA-seq",
15
  "conditions": [
16
  "healthy",
17
+ "breast_cancer"
18
  ],
19
+ "budget_limit": 89988.7088221059,
20
+ "time_limit_days": 160.54122628462622
21
  },
22
  "resources": {
23
  "budget_used": 28900.0,
24
+ "budget_remaining": 61088.71,
25
  "time_used_days": 18.0,
26
+ "time_remaining_days": 142.5,
27
  "samples_consumed": 0,
28
  "compute_hours_used": 12.0
29
  },
 
32
  "step_index": 1,
33
  "action_type": "collect_sample",
34
  "method": null,
35
+ "output_summary": "Collected 6 samples (quality=0.42)",
36
  "success": true,
37
+ "quality_score": 0.422,
38
  "resource_cost": 5000.0,
39
  "time_cost_days": 7.0
40
  },
 
62
  "step_index": 4,
63
  "action_type": "prepare_library",
64
  "method": null,
65
+ "output_summary": "Library prepared (complexity=0.59)",
66
  "success": true,
67
+ "quality_score": 0.595,
68
  "resource_cost": 8000.0,
69
  "time_cost_days": 3.0
70
  },
 
72
  "step_index": 5,
73
  "action_type": "sequence_cells",
74
  "method": null,
75
+ "output_summary": "Sequenced 8578 cells, 9244 genes detected, median UMI=3090",
76
  "success": true,
77
+ "quality_score": 0.359,
78
  "resource_cost": 15000.0,
79
  "time_cost_days": 5.0
80
  },
81
  {
82
  "step_index": 6,
83
  "action_type": "run_qc",
84
  "method": null,
85
  "output_summary": "QC metrics computed",
86
  "success": true,
87
+ "quality_score": 0.821,
88
  "resource_cost": 100.0,
89
  "time_cost_days": 0.5
90
  },
91
  {
92
+ "step_index": 7,
93
  "action_type": "filter_data",
94
  "method": null,
95
+ "output_summary": "Filtered 8578 \u2192 7050 cells (82% retained)",
96
  "success": true,
97
+ "quality_score": 0.822,
98
  "resource_cost": 50.0,
99
  "time_cost_days": 0.2
100
  },
101
+ {
102
+ "step_index": 8,
103
+ "action_type": "cluster_cells",
104
+ "method": null,
105
+ "output_summary": "Action blocked: Cannot cluster before normalisation",
106
+ "success": false,
107
+ "quality_score": 1.0,
108
+ "resource_cost": 100.0,
109
+ "time_cost_days": 0.5
110
+ },
111
  {
112
  "step_index": 9,
113
  "action_type": "normalize_data",
 
142
  "step_index": 12,
143
  "action_type": "cluster_cells",
144
  "method": null,
145
+ "output_summary": "Found 7 clusters",
146
  "success": true,
147
+ "quality_score": 0.741,
148
  "resource_cost": 100.0,
149
  "time_cost_days": 0.5
150
  },
151
  {
152
  "step_index": 13,
153
+ "action_type": "marker_selection",
154
+ "method": null,
155
+ "output_summary": "Action blocked: Cannot select markers without DE results",
156
+ "success": false,
157
+ "quality_score": 1.0,
158
+ "resource_cost": 100.0,
159
+ "time_cost_days": 0.5
160
+ },
161
+ {
162
+ "step_index": 14,
163
  "action_type": "design_followup_experiment",
164
  "method": null,
165
  "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
 
169
  "time_cost_days": 0.5
170
  },
171
  {
172
+ "step_index": 15,
173
  "action_type": "synthesize_conclusion",
174
  "method": null,
175
  "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
 
178
  "resource_cost": 0.0,
179
  "time_cost_days": 0.5
180
  },
181
  {
182
  "step_index": 16,
183
  "action_type": "synthesize_conclusion",
 
222
  "Cannot synthesise conclusion before inferring pathways or mechanisms"
223
  ],
224
  "uncertainty_summary": {
225
+ "avg_uncertainty": 0.177,
226
+ "avg_quality": 0.804
227
  },
228
  "reward_breakdown": {
229
  "validity": -1.0,
 
249
  "latent": {
250
  "cell_populations": [
251
  {
252
+ "name": "luminal_epithelial",
253
+ "proportion": 0.433,
254
  "marker_genes": [
255
+ "KRT8",
256
+ "KRT18",
257
+ "EPCAM"
258
  ],
259
+ "state": "normal"
260
  },
261
  {
262
+ "name": "basal_epithelial",
263
+ "proportion": 0.157,
264
  "marker_genes": [
265
+ "KRT14",
266
+ "KRT5",
267
+ "TP63"
268
  ],
269
  "state": "normal"
270
  },
271
  {
272
+ "name": "fibroblast",
273
+ "proportion": 0.119,
274
+ "marker_genes": [
275
+ "COL1A1",
276
+ "COL3A1",
277
+ "FAP"
278
+ ],
279
+ "state": "quiescent"
280
+ },
281
+ {
282
+ "name": "T_cell",
283
+ "proportion": 0.105,
284
+ "marker_genes": [
285
+ "CD3D",
286
+ "CD3E",
287
+ "CD8A"
288
+ ],
289
+ "state": "activated"
290
+ },
291
+ {
292
+ "name": "macrophage",
293
+ "proportion": 0.096,
294
  "marker_genes": [
295
+ "CD68",
296
+ "CD163",
297
+ "CSF1R"
298
  ],
299
+ "state": "inflammatory"
300
  },
301
  {
302
  "name": "endothelial",
303
+ "proportion": 0.09,
304
  "marker_genes": [
305
+ "PECAM1",
306
+ "VWF",
307
+ "CDH5"
308
  ],
309
  "state": "quiescent"
310
  }
311
  ],
312
  "true_markers": [
313
+ "ERBB2",
314
+ "MKI67",
315
+ "CD274",
316
+ "VIM"
317
  ],
318
  "causal_mechanisms": [
319
+ "ERBB2-driven proliferative signalling",
320
+ "immune evasion via PD-L1 upregulation"
 
321
  ],
322
  "true_pathways": {
323
+ "cell_cycle": 0.889,
324
+ "PI3K_AKT_signalling": 0.803,
325
+ "EMT": 0.757,
326
+ "immune_checkpoint": 0.579,
327
+ "estrogen_signalling": 0.644
328
  },
329
+ "true_de_genes_count": 9,
330
+ "true_regulatory_network_size": 9,
331
  "confounders": {},
332
+ "n_true_cells": 14493,
333
  "technical": {
334
+ "ambient_rna_fraction": 0.05700028722692205,
335
+ "doublet_rate": 0.0716382392677839,
336
+ "dropout_rate": 0.1822398381996976,
337
+ "sample_quality": 0.7058144963381642,
338
  "library_complexity": 0.8,
339
  "capture_efficiency": 0.6
340
  },
 
359
  "followup_designed": true,
360
  "subagent_review_requested": false,
361
  "conclusion_reached": false,
362
+ "n_cells_sequenced": 8578,
363
+ "n_cells_after_filter": 7050,
364
+ "n_clusters_found": "7",
365
  "n_de_genes_found": null,
366
  "n_markers_found": null
367
  },
colab_train_unsloth.ipynb ADDED
@@ -0,0 +1,128 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Unsloth training on Colab\n",
8
+ "\n",
9
+ "Minimal setup: clone repo → install deps → run GRPO training with Unsloth (Qwen3-4B, 4-bit + LoRA).\n",
10
+ "\n",
11
+ "**Runtime**: Enable a GPU (e.g. T4) in Colab: Runtime → Change runtime type → GPU."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "# 1. Clone repo (set branch/tag if needed)\n",
21
+ "REPO_URL = \"https://github.com/mhtruong1031/OpenENV-Hackathon.git\" # or your fork\n",
22
+ "REPO_DIR = \"OpenENV-Hackathon\"\n",
23
+ "\n",
24
+ "!git clone --depth 1 {REPO_URL} {REPO_DIR}\n",
25
+ "%cd {REPO_DIR}"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "# 2. Install requirements: project + train extras + Unsloth (no-deps to keep trl>=0.29)\n",
35
+ "!pip install -q -e \".[train]\"\n",
36
+ "!pip install -q unsloth unsloth_zoo --no-deps\n",
37
+ "\n",
38
+ "# Optional: reward backends\n",
39
+ "!pip install -q sentence-transformers gseapy 2>/dev/null || true"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "# 3. Unsloth must be imported before trl/transformers/peft\n",
49
+ "import unsloth # noqa: F401\n",
50
+ "import torch\n",
51
+ "from pathlib import Path\n",
52
+ "\n",
53
+ "from training_unsloth import make_training_args, run_training\n",
54
+ "\n",
55
+ "print(\"CUDA:\", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"\")\n",
56
+ "Path(\"artifacts\").mkdir(exist_ok=True)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# 4. Training config (small run for Colab T4)\n",
66
+ "args = make_training_args(\n",
67
+ " model_id=\"Qwen/Qwen3-4B-Base\",\n",
68
+ " output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
69
+ " dataset_episodes=16,\n",
70
+ " rollout_steps=10,\n",
71
+ " collection_policy=\"heuristic\",\n",
72
+ " reward_backend=\"local\",\n",
73
+ " domain_randomise=True,\n",
74
+ " num_generations=4,\n",
75
+ " max_completion_length=160,\n",
76
+ " max_prompt_length=1280,\n",
77
+ " max_seq_length=2048,\n",
78
+ " per_device_train_batch_size=2,\n",
79
+ " gradient_accumulation_steps=4,\n",
80
+ " learning_rate=5e-6,\n",
81
+ " num_train_epochs=1.0,\n",
82
+ " logging_steps=1,\n",
83
+ " save_steps=25,\n",
84
+ " trust_remote_code=True,\n",
85
+ " dry_run=False,\n",
86
+ " seed=42,\n",
87
+ ")\n",
88
+ "args"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# 5. Run training\n",
98
+ "result = run_training(args)\n",
99
+ "print(\"Plots:\", result[\"plot_paths\"])"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "# 6. (Optional) Show loss curves\n",
109
+ "from IPython.display import Image, display\n",
110
+ "for name, path in result[\"plot_paths\"].items():\n",
111
+ " display(Image(filename=path))"
112
+ ]
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "kernelspec": {
117
+ "display_name": "Python 3",
118
+ "language": "python",
119
+ "name": "python3"
120
+ },
121
+ "language_info": {
122
+ "name": "python",
123
+ "version": "3.10.0"
124
+ }
125
+ },
126
+ "nbformat": 4,
127
+ "nbformat_minor": 4
128
+ }
inference.ipynb CHANGED
@@ -2,27 +2,30 @@
  "cells": [
  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
  "# Self-Driving Lab Inference on H100 With Unsloth\n",
  "\n",
  "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
- ],
- "id": "a9d34036"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "%pip install -q -U torch transformers unsloth"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "20b36e01"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "import json\n",
  "\n",
@@ -36,14 +39,14 @@
  "if torch.cuda.is_available():\n",
  " print(\"GPU:\", torch.cuda.get_device_name(0))\n",
  " print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "bcf24a2e"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "MODEL_PATH = \"artifacts/grpo-unsloth-output\" # or a Hugging Face repo / base model id\n",
  "SCENARIO_NAME = \"cardiac_disease_de\"\n",
@@ -54,21 +57,20 @@
  " trust_remote_code=True,\n",
  " max_seq_length=2048,\n",
  " load_in_4bit=True,\n",
- " fast_inference=True,\n",
  " prepare_for_inference=True,\n",
  ")\n",
  "\n",
  "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
  "obs = env.reset(seed=SEED)\n",
  "print(format_observation(obs)[:3000])"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "c54f2cfd"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "result = generate_action_with_model(\n",
  " model,\n",
@@ -84,14 +86,14 @@
  "print(result[\"response_text\"])\n",
  "print(\"\\nParsed action:\\n\")\n",
  "result[\"action\"].model_dump() if result[\"action\"] is not None else None"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "f9b25208"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "if result[\"action\"] is not None:\n",
  " next_obs = env.step(result[\"action\"])\n",
@@ -106,14 +108,14 @@
  " print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
  "else:\n",
  " print(\"Model output did not parse into an ExperimentAction.\")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "c2408f52"
  },
  {
  "cell_type": "code",
  "metadata": {},
  "source": [
  "# Optional short closed-loop rollout.\n",
  "obs = env.reset(seed=7)\n",
@@ -145,10 +147,7 @@
  " break\n",
  "\n",
  "trajectory"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "8af34f32"
  }
  ],
  "metadata": {
@@ -163,4 +162,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
- }

  "cells": [
  {
  "cell_type": "markdown",
+ "id": "a9d34036",
  "metadata": {},
  "source": [
  "# Self-Driving Lab Inference on H100 With Unsloth\n",
  "\n",
  "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "20b36e01",
  "metadata": {},
+ "outputs": [],
  "source": [
  "%pip install -q -U torch transformers unsloth"
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "bcf24a2e",
  "metadata": {},
+ "outputs": [],
  "source": [
  "import json\n",
  "\n",
  "if torch.cuda.is_available():\n",
  " print(\"GPU:\", torch.cuda.get_device_name(0))\n",
  " print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "c54f2cfd",
  "metadata": {},
+ "outputs": [],
  "source": [
  "MODEL_PATH = \"artifacts/grpo-unsloth-output\" # or a Hugging Face repo / base model id\n",
  "SCENARIO_NAME = \"cardiac_disease_de\"\n",
  " trust_remote_code=True,\n",
  " max_seq_length=2048,\n",
  " load_in_4bit=True,\n",
  " prepare_for_inference=True,\n",
  ")\n",
  "\n",
  "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
  "obs = env.reset(seed=SEED)\n",
  "print(format_observation(obs)[:3000])"
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "f9b25208",
  "metadata": {},
+ "outputs": [],
  "source": [
  "result = generate_action_with_model(\n",
  " model,\n",
  "print(result[\"response_text\"])\n",
  "print(\"\\nParsed action:\\n\")\n",
  "result[\"action\"].model_dump() if result[\"action\"] is not None else None"
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "c2408f52",
  "metadata": {},
+ "outputs": [],
  "source": [
  "if result[\"action\"] is not None:\n",
  " next_obs = env.step(result[\"action\"])\n",
  " print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
  "else:\n",
  " print(\"Model output did not parse into an ExperimentAction.\")"
+ ]
  },
  {
  "cell_type": "code",
+ "execution_count": null,
+ "id": "8af34f32",
  "metadata": {},
+ "outputs": [],
  "source": [
  "# Optional short closed-loop rollout.\n",
  "obs = env.reset(seed=7)\n",
  " break\n",
  "\n",
  "trajectory"
+ ]
  }
  ],
  "metadata": {
  },
  "nbformat": 4,
  "nbformat_minor": 5
+ }
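The closed-loop rollout cell above follows a parse-act-step pattern: generate an action, bail out when it does not parse, otherwise step the environment until it reports done. A minimal sketch of that loop with hypothetical stand-ins (`fake_policy` and `fake_step` are illustrations only, not the project's `BioExperimentEnvironment` or model API):

```python
# Hypothetical stand-ins that mimic the rollout shape used in inference.ipynb.
def fake_policy(obs):
    # Return a parsed action, or None when the model output fails to parse.
    return {"action": "run_qc"} if obs["step"] == 0 else None

def fake_step(obs, action):
    # Advance the simulated environment by one step.
    return {"step": obs["step"] + 1, "done": obs["step"] + 1 >= 2}

obs = {"step": 0, "done": False}
trajectory = []
for _ in range(4):  # cap the rollout length
    action = fake_policy(obs)
    if action is None:
        break  # model output did not parse into an action
    trajectory.append(action)
    obs = fake_step(obs, action)
    if obs["done"]:
        break
```

The cap on iterations mirrors the notebook's short rollout: the loop always terminates even if the environment never reports done.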
pyproject.toml CHANGED
@@ -12,7 +12,7 @@ build-backend = "setuptools.build_meta"
12
  name = "openenv-bio-experiment"
13
  version = "0.1.0"
14
  description = "RL environment for biological experiment pipeline planning"
15
- requires-python = ">=3.10"
16
  dependencies = [
17
  "openenv-core[core]>=0.2.0",
18
  "numpy>=1.24.0",
@@ -34,12 +34,15 @@ train = [
34
  "accelerate>=1.13.0",
35
  "bitsandbytes>=0.45.0",
36
  "datasets>=4.6.1",
 
37
  "matplotlib>=3.10.8",
38
  "peft>=0.15.0",
39
  "torch>=2.10.0",
40
- "transformers>=5.3.0",
41
- "trl>=0.29.0",
42
- "unsloth",
 
 
43
  ]
44
 
45
  [project.scripts]
 
12
  name = "openenv-bio-experiment"
13
  version = "0.1.0"
14
  description = "RL environment for biological experiment pipeline planning"
15
+ requires-python = ">=3.10,<3.13"
16
  dependencies = [
17
  "openenv-core[core]>=0.2.0",
18
  "numpy>=1.24.0",
 
34
  "accelerate>=1.13.0",
35
  "bitsandbytes>=0.45.0",
36
  "datasets>=4.6.1",
37
+ "ipykernel>=6.0.0",
38
  "matplotlib>=3.10.8",
39
  "peft>=0.15.0",
40
  "torch>=2.10.0",
41
+ "torchvision>=0.20.0", # required by transformers for Qwen3.5 (image_utils)
42
+ "transformers>=5.3.0", # 5.3+ required for Qwen3.5 (qwen3_5 model type)
43
+ "llm-blender>=0.0.2", # required by trl GRPOTrainer judges
44
+ "mergekit>=0.1.0", # required by trl GRPOTrainer/callbacks
45
+ "trl>=0.29.0", # GRPOTrainer; 0.29+ compatible with transformers 5.3
46
  ]
47
 
48
  [project.scripts]
requirements-unsloth.txt ADDED
@@ -0,0 +1,13 @@
+ # Unsloth path: install after uv sync --extra train
+ # Unsloth pins trl<=0.24; this project uses trl>=0.29 for training_script.py.
+ # Use --no-deps to avoid downgrading trl (Unsloth works with TRL 0.29 in practice).
+ #
+ # Install:
+ # uv sync --extra train
+ # uv pip install unsloth unsloth_zoo --no-deps
+ #
+ # Run training with Qwen3-4B-Base:
+ # uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b
+ #
+ unsloth>=2025.10.14
+ unsloth_zoo
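The `--no-deps` note above exists because Unsloth's `trl<=0.24` pin conflicts with the project's `trl>=0.29` pin, so letting pip resolve Unsloth's dependencies would downgrade trl. A stdlib-only sketch of that pin comparison (`version_tuple` and `satisfies_pin` are hypothetical helpers for illustration, not part of the repo):

```python
def version_tuple(v: str) -> tuple:
    # Parse "0.29.0" -> (0, 29, 0); stops at the first non-numeric suffix
    # in a component (so "0rc1" contributes 0).
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def satisfies_pin(installed: str, minimum: str) -> bool:
    # True when the installed version meets a ">=" style pin.
    return version_tuple(installed) >= version_tuple(minimum)
```

With these helpers, `satisfies_pin("0.29.0", "0.29")` holds (the project pin is met), while `satisfies_pin("0.24.0", "0.29")` does not, which is exactly the downgrade the `--no-deps` install avoids.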
run_agent_unsloth.py CHANGED
@@ -40,7 +40,7 @@ LOAD_IN_4BIT = (
  not in {"0", "false", "off"}
  )
  FAST_INFERENCE = (
- os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "1").strip().lower()
  not in {"0", "false", "off"}
  )

  not in {"0", "false", "off"}
  )
  FAST_INFERENCE = (
+ os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "0").strip().lower()
  not in {"0", "false", "off"}
  )
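The change above flips the fast-inference default from on ("1") to off ("0") using a small env-flag idiom. A sketch of the same pattern (`env_flag` is a hypothetical helper that mirrors the expression in `run_agent_unsloth.py`):

```python
import os

def env_flag(name: str, default: str) -> bool:
    # Any value other than "0"/"false"/"off" (case-insensitive) enables the flag.
    return os.getenv(name, default).strip().lower() not in {"0", "false", "off"}

# With the new default of "0", fast inference stays off unless explicitly opted in:
os.environ.pop("RUN_AGENT_UNSLOTH_FAST_INFERENCE", None)
off_by_default = env_flag("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "0")  # False

os.environ["RUN_AGENT_UNSLOTH_FAST_INFERENCE"] = "1"
opted_in = env_flag("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "0")  # True
```

Note the idiom treats unrecognised strings as "on", so only the explicit `0`/`false`/`off` spellings disable the flag.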
 
scripts/register_kernel_312.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ # Register a Python 3.12 Jupyter kernel from the project's uv venv.
+ # Run from project root: bash scripts/register_kernel_312.sh
+
+ set -e
+ cd "$(dirname "$0")/.."
+
+ echo "Ensuring Python 3.12 venv..."
+ uv python install 3.12
+ UV_TORCH_BACKEND="${UV_TORCH_BACKEND:-cu128}" uv sync --extra train
+
+ echo "Registering Jupyter kernel 'OpenEnv Bio (Python 3.12)'..."
+ uv run python -m ipykernel install --user \
+ --name openenv-bio-312 \
+ --display-name "OpenEnv Bio (Python 3.12)"
+
+ echo "Done. Select 'OpenEnv Bio (Python 3.12)' in the notebook kernel picker."
train.ipynb CHANGED
@@ -7,8 +7,9 @@
  "source": [
  "# Train A Self-Driving Lab Policy on H100\n",
  "\n",
- "This notebook is designed for Jupyter GPU nodes such as H100 clusters.\n",
- "It uses the notebook-friendly helpers in `training_script.py` to build prompts from the same self-driving lab environment state used by `run_agent.py`, preview reference actions, and launch GRPO training without shelling out to the CLI."
  ]
  },
  {
@@ -32,10 +33,15 @@
  "outputs": [],
  "source": [
  "from pathlib import Path\n",
  "\n",
  "import torch\n",
  "\n",
- "from training_script import build_prompt_examples, make_training_args, run_training\n",
  "\n",
  "print(\"CUDA available:\", torch.cuda.is_available())\n",
  "if torch.cuda.is_available():\n",
@@ -53,20 +59,23 @@
  "outputs": [],
  "source": [
  "args = make_training_args(\n",
- " model_id=\"Qwen/Qwen3.5-0.8B\",\n",
  " output_dir=\"artifacts/grpo-h100\",\n",
- " dataset_episodes=32,\n",
- " rollout_steps=10,\n",
  " collection_policy=\"heuristic\",\n",
  " reward_backend=\"local\",\n",
  " domain_randomise=True,\n",
- " num_generations=4,\n",
- " max_completion_length=160,\n",
- " max_prompt_length=1280,\n",
- " per_device_train_batch_size=4,\n",
- " gradient_accumulation_steps=4,\n",
- " learning_rate=5e-6,\n",
  " num_train_epochs=1.0,\n",
  " logging_steps=1,\n",
  " save_steps=25,\n",
  " trust_remote_code=True,\n",
@@ -84,14 +93,18 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "preview_examples = build_prompt_examples(\n",
- " dataset_episodes=1,\n",
- " rollout_steps=args.rollout_steps,\n",
- " collection_policy=args.collection_policy,\n",
- " scenario_names=[\"cardiac_disease_de\"],\n",
- " seed=args.seed,\n",
- " domain_randomise=args.domain_randomise,\n",
  ")\n",
  "\n",
  "print(preview_examples[0][\"prompt\"][:3500])\n",
  "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"

  "source": [
  "# Train A Self-Driving Lab Policy on H100\n",
  "\n",
+ "This notebook trains a GRPO policy for the **same bio-experiment planning task** as `run_agent.py`: choosing structured actions (collect_sample, run_qc, cluster, de_analysis, etc.) step-by-step in the OpenEnv bio-experiment environment.\n",
+ "\n",
+ "**Flow:** Build prompts from `BioExperimentEnvironment` rollouts (same env `run_agent.py` uses) → OpenEnv reward scores actions locally → GRPO trains the model. Uses `build_openenv_reward`, `prepare_prompt_examples`, and `build_grpo_trainer` from `training_script.py`."
  ]
  },
  {
  "outputs": [],
  "source": [
  "from pathlib import Path\n",
+ "import importlib\n",
  "\n",
  "import torch\n",
+ "import training_script as training_script_module\n",
  "\n",
+ "training_script_module = importlib.reload(training_script_module)\n",
+ "make_training_args = training_script_module.make_training_args\n",
+ "prepare_prompt_examples = training_script_module.prepare_prompt_examples\n",
+ "run_training = training_script_module.run_training\n",
  "\n",
  "print(\"CUDA available:\", torch.cuda.is_available())\n",
  "if torch.cuda.is_available():\n",
  "outputs": [],
  "source": [
  "args = make_training_args(\n",
+ " model_id=\"Qwen/Qwen3.5-9B\",\n",
  " output_dir=\"artifacts/grpo-h100\",\n",
+ " dataset_episodes=64, # more data per run\n",
+ " rollout_steps=12, # slightly longer trajectories\n",
  " collection_policy=\"heuristic\",\n",
  " reward_backend=\"local\",\n",
  " domain_randomise=True,\n",
+ "\n",
+ " num_generations=8, # H100 can handle a larger GRPO group\n",
+ " max_completion_length=192, # small bump if completions are being cut off\n",
+ " max_prompt_length=1024, # trim a bit unless you truly need 1280\n",
+ "\n",
+ " per_device_train_batch_size=8, # first thing to try on H100\n",
+ " gradient_accumulation_steps=2, # same effective batch as before, fewer sync steps\n",
+ " learning_rate=1e-5, # slightly more aggressive for LoRA/QLoRA-style RL tuning\n",
  " num_train_epochs=1.0,\n",
+ "\n",
  " logging_steps=1,\n",
  " save_steps=25,\n",
  " trust_remote_code=True,\n",
  "metadata": {},
  "outputs": [],
  "source": [
+ "# Same prompt format run_agent.py sees: SYSTEM_PROMPT + observation\n",
+ "preview_data = prepare_prompt_examples(\n",
+ " make_training_args(\n",
+ " dataset_episodes=1,\n",
+ " rollout_steps=args.rollout_steps,\n",
+ " collection_policy=args.collection_policy,\n",
+ " scenario_name=[\"cardiac_disease_de\"],\n",
+ " seed=args.seed,\n",
+ " domain_randomise=args.domain_randomise,\n",
+ " )\n",
  ")\n",
+ "preview_examples = preview_data[\"examples\"]\n",
  "\n",
  "print(preview_examples[0][\"prompt\"][:3500])\n",
  "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])\n"
train_unsloth.ipynb ADDED
@@ -0,0 +1,149 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train a Self-Driving Lab Policy with Unsloth\n",
+ "\n",
+ "This notebook uses **Unsloth** for fast quantized training on GPU nodes (e.g. H100). It mirrors `train.ipynb` but loads the model via Unsloth's optimized path with 4-bit quantization and LoRA adapters.\n",
+ "\n",
+ "**Model**: Uses **Qwen3-4B-Base** by default. Alternatives:\n",
+ "- `Qwen/Qwen3-4B-Base` (base, no chat template)\n",
+ "- `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`\n",
+ "- `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install Unsloth and training dependencies (run once per session)\n",
+ "# Option A: uv (if using uv-managed venv)\n",
+ "# !uv sync --extra train\n",
+ "# !uv pip install unsloth unsloth_zoo --no-deps\n",
+ "\n",
+ "# Option B: pip\n",
+ "%pip install -q -U torch transformers datasets trl accelerate bitsandbytes unsloth unsloth_zoo matplotlib huggingface_hub\n",
+ "\n",
+ "# Optional extras used by some reward-scoring paths.\n",
+ "%pip install -q -U sentence-transformers gseapy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Unsloth must be imported before trl, transformers, peft\n",
+ "import unsloth # noqa: F401\n",
+ "\n",
+ "from pathlib import Path\n",
+ "import torch\n",
+ "\n",
+ "from training_unsloth import make_training_args, run_training\n",
+ "import training_script as base\n",
+ "\n",
+ "print(\"CUDA available:\", torch.cuda.is_available())\n",
+ "if torch.cuda.is_available():\n",
+ " print(\"GPU:\", torch.cuda.get_device_name(0))\n",
+ " print(\"bf16 supported:\", torch.cuda.is_bf16_supported())\n",
+ "\n",
+ "Path(\"artifacts\").mkdir(exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "args = make_training_args(\n",
+ " model_id=\"Qwen/Qwen3-4B-Base\",\n",
+ " output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
+ " dataset_episodes=32,\n",
+ " rollout_steps=10,\n",
+ " collection_policy=\"heuristic\",\n",
+ " reward_backend=\"local\",\n",
+ " domain_randomise=True,\n",
+ " num_generations=4,\n",
+ " max_completion_length=160,\n",
+ " max_prompt_length=1280,\n",
+ " max_seq_length=2048,\n",
+ " per_device_train_batch_size=4,\n",
+ " gradient_accumulation_steps=4,\n",
+ " learning_rate=5e-6,\n",
+ " num_train_epochs=1.0,\n",
+ " logging_steps=1,\n",
+ " save_steps=25,\n",
+ " trust_remote_code=True,\n",
+ " dry_run=False,\n",
+ " seed=42,\n",
+ ")\n",
+ "\n",
+ "args"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "preview_examples = base.build_prompt_examples(\n",
+ " dataset_episodes=1,\n",
+ " rollout_steps=args.rollout_steps,\n",
+ " collection_policy=args.collection_policy,\n",
+ " scenario_names=[\"cardiac_disease_de\"],\n",
+ " seed=args.seed,\n",
+ " domain_randomise=args.domain_randomise,\n",
+ ")\n",
+ "\n",
+ "print(preview_examples[0][\"prompt\"][:3500])\n",
+ "print(\"\\nReference action:\\n\", preview_examples[0][\"reference_action\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional smoke test before a full run.\n",
+ "dry_run_args = make_training_args(**{**vars(args), \"dry_run\": True})\n",
+ "dry_run_result = run_training(dry_run_args)\n",
+ "len(dry_run_result[\"examples\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import Image, display\n",
+ "\n",
+ "train_result = run_training(args)\n",
+ "for name, plot_path in train_result[\"plot_paths\"].items():\n",
+ " print(name, plot_path)\n",
+ " display(Image(filename=plot_path))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
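The notebook's smoke-test cell clones the training args with `dry_run` overridden via `vars()`, leaving the original Namespace untouched. The same idiom in isolation (the field names here are illustrative, not the full argument set):

```python
import argparse

# Clone a Namespace with one field overridden; the original is unchanged.
args = argparse.Namespace(model_id="Qwen/Qwen3-4B-Base", dry_run=False, seed=42)
dry_run_args = argparse.Namespace(**{**vars(args), "dry_run": True})
```

Because `vars(args)` returns the Namespace's attribute dict, the `{**..., "dry_run": True}` merge produces a fresh dict, so mutating `dry_run_args` never leaks back into `args`.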
training_script.py CHANGED
@@ -1266,11 +1266,122 @@ def resolve_torch_runtime() -> Dict[str, Any]:
  }

  def load_model_artifacts(
  model_id: str,
  *,
  trust_remote_code: bool,
  ):
  from transformers import AutoModelForCausalLM, AutoTokenizer

  runtime = resolve_torch_runtime()
@@ -1295,6 +1406,109 @@ def load_model_artifacts(
  return tokenizer, model

  def generate_action_with_model(
  model: Any,
  tokenizer: Any,
@@ -1363,20 +1577,10 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
  "model": model,
  }

- scenario_names = selected_scenarios(args.scenario_name)
- examples = build_prompt_examples(
- dataset_episodes=args.dataset_episodes,
- rollout_steps=args.rollout_steps,
- collection_policy=args.collection_policy,
- scenario_names=scenario_names,
- seed=args.seed,
- domain_randomise=args.domain_randomise,
- )
- reward_fn = OpenEnvReward(
- reward_backend=args.reward_backend,
- base_url=args.base_url,
- domain_randomise=args.domain_randomise,
- )

  if args.dry_run:
  run_dry_run_preview(examples, reward_fn, args.output_dir)
@@ -1389,41 +1593,30 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
  }

  from datasets import Dataset
- from trl import GRPOConfig, GRPOTrainer
-
  train_dataset = Dataset.from_list(examples)
  tokenizer, model = load_model_artifacts(
  args.model_id,
  trust_remote_code=args.trust_remote_code,
  )
- config = GRPOConfig(
- output_dir=args.output_dir,
- learning_rate=args.learning_rate,
- per_device_train_batch_size=args.per_device_train_batch_size,
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- num_generations=args.num_generations,
- max_completion_length=args.max_completion_length,
- num_train_epochs=args.num_train_epochs,
- logging_steps=args.logging_steps,
- save_steps=args.save_steps,
- bf16=runtime["bf16"],
- fp16=runtime["fp16"],
- report_to="none",
- remove_unused_columns=False,
- )

  print(
  f"Training runtime: device={runtime['device']} "
  f"name={runtime['device_name']} "
  f"dtype={runtime['dtype']}"
  )

- trainer = GRPOTrainer(
  model=model,
- reward_funcs=reward_fn,
- args=config,
  train_dataset=train_dataset,
- processing_class=tokenizer,
  )
  trainer.train()
  trainer.save_model(args.output_dir)

  }

+ def _guard_invalid_torchao_version() -> None:
+ """Treat malformed torchao installs as unavailable for HF imports."""
+ import functools
+ import importlib.metadata as importlib_metadata
+ import sys
+ from packaging.version import InvalidVersion, Version
+
+ if getattr(importlib_metadata, "_openenv_torchao_guard_installed", False):
+ metadata_guard_installed = True
+ else:
+ original_version = importlib_metadata.version
+
+ def guarded_version(distribution_name: str) -> str:
+ version = original_version(distribution_name)
+ if distribution_name.lower() == "torchao":
+ try:
+ Version(version)
+ except InvalidVersion as exc:
+ raise importlib_metadata.PackageNotFoundError(
+ f"Malformed torchao version metadata: {version!r}"
+ ) from exc
+ return version
+
+ importlib_metadata.version = guarded_version
+ importlib_metadata._openenv_torchao_guard_installed = True
+ metadata_guard_installed = False
+
+ import_utils = sys.modules.get("transformers.utils.import_utils")
+ if import_utils is not None and not getattr(import_utils, "_openenv_torchao_guard_installed", False):
+ original_is_package_available = import_utils._is_package_available
+
+ def guarded_is_package_available(
+ pkg_name: str,
+ return_version: bool = False,
+ ):
+ if pkg_name != "torchao":
+ return original_is_package_available(pkg_name, return_version=return_version)
+ is_available, package_version = original_is_package_available(
+ pkg_name,
+ return_version=True,
+ )
+ if not is_available:
+ return (False, package_version) if return_version else (False, None)
+ try:
+ Version(package_version)
+ except InvalidVersion:
+ return (False, "0") if return_version else (False, None)
+ return (True, package_version) if return_version else (True, None)
+
+ min_version = getattr(import_utils, "TORCHAO_MIN_VERSION", "0")
+
+ @functools.lru_cache
+ def guarded_is_torchao_available(min_version_override: str = min_version) -> bool:
+ is_available, package_version = guarded_is_package_available(
+ "torchao",
+ return_version=True,
+ )
+ if not is_available:
+ return False
+ try:
+ return Version(package_version) >= Version(min_version_override)
+ except InvalidVersion:
+ return False
+
+ if hasattr(import_utils.is_torchao_available, "cache_clear"):
+ import_utils.is_torchao_available.cache_clear()
+ import_utils._is_package_available = guarded_is_package_available
+ import_utils.is_torchao_available = guarded_is_torchao_available
+ import_utils._openenv_torchao_guard_installed = True
+
+ transformers_utils = sys.modules.get("transformers.utils")
+ if transformers_utils is not None:
+ transformers_utils.is_torchao_available = guarded_is_torchao_available
+
+ if metadata_guard_installed and import_utils is None:
+ return
+
+
+ def _guard_partial_vllm_install() -> None:
+ """Treat partial vLLM installs as unavailable for TRL imports."""
+ import functools
+ import importlib
+
+ try:
+ import trl.import_utils as trl_import_utils
+ except Exception:
+ return
+
+ if getattr(trl_import_utils, "_openenv_vllm_guard_installed", False):
+ return
+
+ def _has_usable_vllm() -> bool:
+ try:
+ importlib.import_module("vllm")
+ importlib.import_module("vllm.distributed.device_communicators.pynccl")
+ importlib.import_module("vllm.distributed.utils")
+ except Exception:
+ return False
+ return True
+
+ @functools.lru_cache
+ def guarded_is_vllm_available(*args: Any, **kwargs: Any) -> bool:
+ return _has_usable_vllm()
+
+ if hasattr(trl_import_utils.is_vllm_available, "cache_clear"):
+ trl_import_utils.is_vllm_available.cache_clear()
+ trl_import_utils.is_vllm_available = guarded_is_vllm_available
+ trl_import_utils._openenv_vllm_guard_installed = True
+
+
  def load_model_artifacts(
  model_id: str,
  *,
  trust_remote_code: bool,
  ):
+ _guard_invalid_torchao_version()
  from transformers import AutoModelForCausalLM, AutoTokenizer

  runtime = resolve_torch_runtime()

  return tokenizer, model

+ def build_openenv_reward(args: argparse.Namespace) -> OpenEnvReward:
+ """Return the OpenEnv-compatible reward callable used by GRPO."""
+ return OpenEnvReward(
+ reward_backend=args.reward_backend,
+ base_url=args.base_url,
+ domain_randomise=args.domain_randomise,
+ )
+
+
+ def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
+ """Build the OpenEnv rollout states that seed GRPO prompts."""
+ scenario_names = selected_scenarios(args.scenario_name)
+ examples = build_prompt_examples(
+ dataset_episodes=args.dataset_episodes,
+ rollout_steps=args.rollout_steps,
+ collection_policy=args.collection_policy,
+ scenario_names=scenario_names,
+ seed=args.seed,
+ domain_randomise=args.domain_randomise,
+ )
+ return {
+ "scenario_names": scenario_names,
+ "examples": examples,
+ }
+
+
+ def build_grpo_config(
+ args: argparse.Namespace,
+ runtime: Dict[str, Any],
+ ):
+ import inspect
+
+ _guard_invalid_torchao_version()
+ _guard_partial_vllm_install()
+ from trl import GRPOConfig
+
+ config_kwargs = {
+ "output_dir": args.output_dir,
+ "learning_rate": args.learning_rate,
+ "per_device_train_batch_size": args.per_device_train_batch_size,
+ "gradient_accumulation_steps": args.gradient_accumulation_steps,
+ "num_generations": args.num_generations,
+ "max_completion_length": args.max_completion_length,
+ "max_prompt_length": args.max_prompt_length,
+ "num_train_epochs": args.num_train_epochs,
+ "logging_steps": args.logging_steps,
+ "save_steps": args.save_steps,
+ "bf16": runtime["bf16"],
+ "fp16": runtime["fp16"],
+ "report_to": "none",
+ "remove_unused_columns": False,
+ }
+ supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
+
+ # Older TRL builds may expose a single max_length knob instead of
+ # separate prompt/completion limits.
+ if (
+ "max_length" in supported_params
+ and "max_prompt_length" not in supported_params
+ and "max_completion_length" not in supported_params
+ ):
+ config_kwargs["max_length"] = (
+ args.max_prompt_length + args.max_completion_length
+ )
+
+ filtered_kwargs = {
+ key: value
+ for key, value in config_kwargs.items()
+ if key in supported_params
+ }
+ skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
+ if skipped:
+ print(
+ "GRPOConfig compatibility: skipping unsupported fields "
+ f"{', '.join(skipped)}"
+ )
+
+ return GRPOConfig(**filtered_kwargs)
+
+
+ def build_grpo_trainer(
+ *,
+ model: Any,
+ tokenizer: Any,
+ reward_func: Any,
+ train_dataset: Any,
+ args: argparse.Namespace,
+ runtime: Dict[str, Any],
+ ):
+ _guard_invalid_torchao_version()
+ _guard_partial_vllm_install()
+ from trl import GRPOTrainer
+
+ config = build_grpo_config(args, runtime)
+ return GRPOTrainer(
+ model=model,
+ reward_funcs=reward_func,
+ args=config,
+ train_dataset=train_dataset,
+ processing_class=tokenizer,
+ )
+
+
  def generate_action_with_model(
  model: Any,
  tokenizer: Any,

  "model": model,
  }

+ prompt_data = prepare_prompt_examples(args)
+ scenario_names = prompt_data["scenario_names"]
+ examples = prompt_data["examples"]
+ reward_fn = build_openenv_reward(args)

  if args.dry_run:
  run_dry_run_preview(examples, reward_fn, args.output_dir)

  }

  from datasets import Dataset
  train_dataset = Dataset.from_list(examples)
  tokenizer, model = load_model_artifacts(
  args.model_id,
  trust_remote_code=args.trust_remote_code,
  )

  print(
  f"Training runtime: device={runtime['device']} "
  f"name={runtime['device_name']} "
  f"dtype={runtime['dtype']}"
  )
+ print(
+ "OpenEnv reward: "
+ f"backend={args.reward_backend} scenarios={len(scenario_names)} "
+ f"examples={len(examples)}"
+ )

+ trainer = build_grpo_trainer(
  model=model,
  train_dataset=train_dataset,
+ tokenizer=tokenizer,
+ reward_func=reward_fn,
+ args=args,
+ runtime=runtime,
  )
  trainer.train()
  trainer.save_model(args.output_dir)
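The new `build_grpo_config` filters its kwargs against `GRPOConfig.__init__`'s signature so older TRL builds simply skip fields they don't support instead of crashing. The pattern in isolation, with `ToyConfig` as a hypothetical stand-in for an older config class:

```python
import inspect

class ToyConfig:
    # Stand-in for an older GRPOConfig without max_prompt_length.
    def __init__(self, output_dir, learning_rate=1e-5, max_length=None):
        self.output_dir = output_dir
        self.learning_rate = learning_rate
        self.max_length = max_length

desired = {
    "output_dir": "artifacts/demo",
    "learning_rate": 5e-6,
    "max_prompt_length": 1280,  # not accepted by ToyConfig
}

# Keep only kwargs the constructor actually accepts; report the rest.
supported = set(inspect.signature(ToyConfig.__init__).parameters)
filtered = {k: v for k, v in desired.items() if k in supported}
skipped = sorted(set(desired) - set(filtered))

config = ToyConfig(**filtered)
```

The same filter lets one call site serve multiple library versions; the trade-off is that silently dropped fields must be surfaced (here via `skipped`, in the diff via the compatibility print) or misconfigurations go unnoticed.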
training_unsloth.py CHANGED
@@ -6,6 +6,9 @@ but arranges the Unsloth path in the more typical pattern:
 2. load a quantized model
 3. apply LoRA adapters
 4. train with an explicit OpenEnv reward function
+
+NOTE: Unsloth must be imported before trl, transformers, peft. Import this
+module before training_script.
 """
 
 from __future__ import annotations
@@ -15,6 +18,9 @@ import random
 from pathlib import Path
 from typing import Any, Dict, Optional, Sequence
 
+# Unsloth must be imported before trl/transformers/peft for optimizations.
+import unsloth  # noqa: F401
+
 import training_script as base
 
 DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
@@ -36,11 +42,21 @@ LORA_TARGET_MODULES = [
 def require_unsloth():
     try:
         from unsloth import FastLanguageModel, PatchFastRL
-    except ImportError as exc:  # pragma: no cover - depends on optional extra
-        raise RuntimeError(
-            "Unsloth is not installed. Run `uv sync --extra train` "
-            "to install the H100/quantized training dependencies."
-        ) from exc
+    except ImportError as exc:
+        msg = str(exc)
+        if "vllm.lora" in msg or "vllm" in msg.lower():
+            raise RuntimeError(
+                f"Unsloth failed: {exc}. "
+                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
+                "  pip install 'vllm==0.8.2'  # requires torch 2.6\n"
+                "  pip install 'vllm==0.7.3'  # alternative\n"
+                "If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
+            ) from exc
+        if "unsloth" in msg.lower():
+            raise RuntimeError(
+                "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
+            ) from exc
+        raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
     return FastLanguageModel, PatchFastRL
@@ -74,11 +90,6 @@ def build_argument_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Disable 4-bit quantized loading and use the wider base weights.",
     )
-    parser.add_argument(
-        "--disable-fast-inference",
-        action="store_true",
-        help="Disable Unsloth fast inference kernels where supported.",
-    )
     parser.add_argument(
         "--lora-r",
         type=int,
@@ -125,7 +136,7 @@ def load_model_artifacts(
     trust_remote_code: bool,
     max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
     load_in_4bit: bool = True,
-    fast_inference: bool = True,
+    fast_inference: bool = False,
     prepare_for_inference: bool = False,
 ):
     FastLanguageModel, _ = require_unsloth()
@@ -136,7 +147,7 @@ def load_model_artifacts(
         FastLanguageModel,
         model_name=model_id,
         max_seq_length=max_seq_length,
-        dtype="auto",
+        dtype=runtime["dtype"],
         load_in_4bit=load_in_4bit,
         fast_inference=fast_inference,
         trust_remote_code=trust_remote_code,
@@ -211,23 +222,43 @@ def build_grpo_config(
     args: argparse.Namespace,
     runtime: Dict[str, Any],
 ):
+    import inspect
+
+    base._guard_invalid_torchao_version()
+    base._guard_partial_vllm_install()
     from trl import GRPOConfig
 
-    return GRPOConfig(
-        output_dir=args.output_dir,
-        learning_rate=args.learning_rate,
-        per_device_train_batch_size=args.per_device_train_batch_size,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        num_generations=args.num_generations,
-        max_completion_length=args.max_completion_length,
-        num_train_epochs=args.num_train_epochs,
-        logging_steps=args.logging_steps,
-        save_steps=args.save_steps,
-        bf16=runtime["bf16"],
-        fp16=runtime["fp16"],
-        report_to="none",
-        remove_unused_columns=False,
-    )
+    supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
+    config_kwargs = {
+        "output_dir": args.output_dir,
+        "learning_rate": args.learning_rate,
+        "per_device_train_batch_size": args.per_device_train_batch_size,
+        "gradient_accumulation_steps": args.gradient_accumulation_steps,
+        "num_generations": args.num_generations,
+        "max_completion_length": args.max_completion_length,
+        "num_train_epochs": args.num_train_epochs,
+        "logging_steps": args.logging_steps,
+        "save_steps": args.save_steps,
+        "bf16": runtime["bf16"],
+        "fp16": runtime["fp16"],
+        "report_to": "none",
+        "remove_unused_columns": False,
+    }
+    # Only add max_prompt_length if this TRL version supports it; UnslothGRPOTrainer can
+    # fail when passing it to parent, so we only pass when explicitly supported.
+    if "max_prompt_length" in supported_params:
+        config_kwargs["max_prompt_length"] = None  # text-only; avoids image_token_id crash
+    if (
+        "max_length" in supported_params
+        and "max_prompt_length" not in supported_params
+        and "max_completion_length" not in supported_params
+    ):
+        config_kwargs["max_length"] = getattr(args, "max_prompt_length", 1024) + args.max_completion_length
+    filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}
+    skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
+    if skipped:
+        print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(skipped)}")
+    return GRPOConfig(**filtered_kwargs)
 
 
 def build_unsloth_grpo_trainer(
@@ -239,6 +270,8 @@ def build_unsloth_grpo_trainer(
     args: argparse.Namespace,
     runtime: Dict[str, Any],
 ):
+    base._guard_invalid_torchao_version()
+    base._guard_partial_vllm_install()
     from trl import GRPOTrainer
 
     config = build_grpo_config(args, runtime)
@@ -311,7 +344,7 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         trust_remote_code=args.trust_remote_code,
         max_seq_length=args.max_seq_length,
         load_in_4bit=not args.disable_4bit,
-        fast_inference=not args.disable_fast_inference,
+        fast_inference=False,
         prepare_for_inference=True,
     )
     device = getattr(model, "device", "unknown")
@@ -352,7 +385,7 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         trust_remote_code=args.trust_remote_code,
         max_seq_length=args.max_seq_length,
         load_in_4bit=not args.disable_4bit,
-        fast_inference=not args.disable_fast_inference,
+        fast_inference=False,
     )
     # 2. Apply LoRA adapters.
     model = apply_lora_adapters(FastLanguageModel, model, args)
@@ -378,6 +411,11 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         args=args,
         runtime=runtime,
     )
+    # Workaround: UnslothGRPOTrainer expects vision token IDs for max_prompt_length
+    # truncation; text-only models don't have them. Set to None so protected=[].
+    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
+        if not hasattr(trainer, attr):
+            setattr(trainer, attr, None)
     trainer.train()
    trainer.save_model(args.output_dir)
     tokenizer.save_pretrained(args.output_dir)
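The signature-filtering pattern in `build_grpo_config` above (build a kwargs dict, then keep only the keys that `inspect.signature` reports on the target `__init__`) is a general way to construct configs across library versions whose accepted parameters differ. A minimal, self-contained sketch; `ExampleConfig` is a stand-in class invented for illustration, not a TRL API:

```python
import inspect

class ExampleConfig:
    """Stand-in for a library config whose accepted kwargs vary by version."""
    def __init__(self, output_dir="out", learning_rate=1e-5, bf16=False):
        self.output_dir = output_dir
        self.learning_rate = learning_rate
        self.bf16 = bf16

def build_config_filtered(cls, **wanted):
    # Keep only kwargs the installed version's __init__ actually accepts.
    supported = set(inspect.signature(cls.__init__).parameters)
    filtered = {k: v for k, v in wanted.items() if k in supported}
    skipped = sorted(set(wanted) - set(filtered))
    if skipped:
        print(f"skipping unsupported fields: {', '.join(skipped)}")
    return cls(**filtered)

cfg = build_config_filtered(
    ExampleConfig,
    output_dir="training/out",
    learning_rate=5e-6,
    bf16=True,
    num_generations=4,  # not accepted by this stand-in; dropped with a notice
)
```

The trade-off of this approach is that a typo in a kwarg name is silently dropped rather than raised, which is why the real code prints the skipped fields.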
unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py ADDED
File without changes
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1936 @@
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
+try:
+    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+except:
+    def reset_unsloth_gradient_checkpointing_buffers(): pass
+def prepare_for_training_mode(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        # Enable training mode
+        _was_training = None
+        # Get gradient checkpointing setting from training arguments
+        use_gc = getattr(self.args, 'gradient_checkpointing', True)
+        if hasattr(self, 'model') and hasattr(self.model, "training"):
+            _was_training = self.model.training
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+            self.model.for_training(use_gradient_checkpointing=use_gc)
+        output = f(self, *args, **kwargs)
+        # Restore previous mode when possible
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+            if _was_training is False:
+                self.model.for_inference()
+            elif _was_training is True and hasattr(self.model, "for_training"):
+                self.model.for_training(use_gradient_checkpointing=use_gc)
+        # Reset gradient checkpointing buffers to free memory while staying ready for next run
+        try:
+            reset_unsloth_gradient_checkpointing_buffers()
+        except:
+            pass
+        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+        try:
+            import wandb
+            wandb.finish()
+        except:
+            pass
+        return output
+    return wrapper
+pass
+
+torch_compile_options = {
+    "epilogue_fusion" : True,
+    "max_autotune" : False,
+    "shape_padding" : True,
+    "trace.enabled" : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_hidden_states_selective_log_softmax(
+    hidden_states: torch.Tensor,
+    lm_head: torch.Tensor,
+    index: torch.Tensor,
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
+
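Both chunked helpers above compute the same per-token quantity, `selected_logit - logsumexp(logits)`, which equals the log of the softmax probability of the selected index; chunking only bounds peak memory. A stdlib-only sketch of that identity on toy values (no torch, numbers chosen for illustration):

```python
import math

def selective_log_softmax(logits, index):
    # log softmax of the selected logit: logits[index] - logsumexp(logits)
    lse = math.log(sum(math.exp(x) for x in logits))
    return logits[index] - lse

logits = [2.0, 0.5, -1.0]
logp = selective_log_softmax(logits, 0)

# Cross-check: exp(logp) must equal the softmax probability of index 0,
# and the full softmax must sum to 1.
probs = [math.exp(x) / sum(math.exp(y) for y in logits) for x in logits]
assert abs(math.exp(logp) - probs[0]) < 1e-12
```

Computing only the selected column plus a logsumexp avoids materializing the full log-softmax matrix, which is the point of the chunked variants above.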
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+
+    padding_mask = (prompt_section == pad_token_id)
+
+    pad_token_counts = padding_mask.sum(dim=1)
+
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
+
+    Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
+    and pad are pad tokens, this function would make a completion mask that would 0 out the pad
+    and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    # Must do stable=True since binary mask is unordered
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
+
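`left_pack_padding` above relies on `stable=True`: a stable sort keeps the relative order of the non-pad tokens (and of the pads) while moving all pads to the right. The same behavior in plain Python, with a hypothetical pad id of 0 chosen only for this sketch:

```python
PAD = 0  # hypothetical pad token id for this sketch

def left_pack_row(row, pad_id=PAD):
    # Stable partition: non-pad tokens keep their order, pads move to the right.
    # This mirrors argsort(mask, descending=True, stable=True) + gather on one row.
    return [tok for tok in row if tok != pad_id] + [tok for tok in row if tok == pad_id]

batch = [
    [PAD, PAD, 5, 7, PAD, 9],
    [1, PAD, 2, 3, PAD, PAD],
]
packed = [left_pack_row(r) for r in batch]
print(packed)  # -> [[5, 7, 9, 0, 0, 0], [1, 2, 3, 0, 0, 0]]
```

Without stability, a sort on the boolean mask could permute tokens within the non-pad group, which would corrupt the sequence.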
+def align_logprobs_with_mask(
+    logprob_tensor: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pad_value: float = 0.0
+) -> torch.Tensor:
+    """
+    Aligns a log probability tensor with a given attention mask.
+    """
+
+    device = logprob_tensor.device
+    batch_size, logprob_seq_len = logprob_tensor.shape
+    mask_seq_len = attention_mask.shape[1]
+
+    padded_logprobs = torch.full(
+        attention_mask.shape,
+        fill_value=pad_value,
+        dtype=logprob_tensor.dtype,
+        device=device
+    )
+
+    left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+    cols = torch.arange(logprob_seq_len, device=device)
+    dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+    # Create destination row indices
+    # Shape: [batch_size, logprob_seq_len]
+    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+    # --- 4. Filter out-of-bounds indices and perform assignment ---
+    # Create a mask to identify only the indices that are within the bounds
+    # of the target tensor's sequence length.
+    valid_mask = dest_indices < mask_seq_len
+
+    # Use this mask to select only the valid row indices, column indices,
+    # and the corresponding values from the logprob tensor.
+    # This flattens the selected elements into 1D tensors.
+    valid_rows = row_indices[valid_mask]
+    valid_cols = dest_indices[valid_mask]
+    valid_vals = logprob_tensor[valid_mask]
+
+    # Place the valid values into their correct positions in the padded tensor
+    # using a single, efficient advanced indexing operation.
+    padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+    return padded_logprobs
+
+def autotune_batch_and_chunks(
+    total_input_rows,
+    seq_len,
+    hidden_size,
+    vocab_size,
+    dtype_bytes=16,
+    multiplier=None
+):
+    if multiplier is None:
+        final_m = max(4, seq_len // 4096)
+    else:
+        final_m = multiplier
+
+    if torch.cuda.is_available():
+        free_bytes, _ = torch.cuda.mem_get_info()
+        limit_gb = (free_bytes / (1024**3))*.80
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        # For XPU: estimate free memory from total - reserved
+        total_mem = torch.xpu.get_device_properties(0).total_memory
+        reserved_mem = torch.xpu.memory_reserved()
+        free_bytes = total_mem - reserved_mem
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    else:
+        # Fallback: assume 8GB available
+        limit_gb = 8.0
+
+    bytes_to_gb = 1024**3
+
+    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+    base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+    logits_gb = base_logits / final_m
+
+    total_mem_gb = hidden_gb + logits_gb
+
+    valid_mask = total_mem_gb <= limit_gb
+    valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+    if valid_indices.shape[0] == 0:
+        # This means your GPU will OOM
+        return 4, final_m
+
+    best_idx = valid_indices[0].item()
+    final_b = int(b_vals[best_idx].item())
+
+    return final_b, final_m
+
+def sanitize_logprob(logprob):
+    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+    Filters NaN logprobs from vLLM outputs."""
+    value = logprob.logprob
+    if math.isnan(value):
+        logging.getLogger(__name__).warning(
+            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+        )
+        return None
+    return value
+@dataclass
+class UnslothCPOConfig(CPOConfig):
+    """
+
+    Configuration class for the [`CPOTrainer`].
+
+    This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
+    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+    differ from those in [`~transformers.TrainingArguments`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+            to use the default data collator.
+        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+            Maximum length of the prompt. This argument is required if you want to use the default data collator.
+        max_completion_length (`int`, *optional*):
+            Maximum length of the completion. This argument is required if you want to use the default data collator
+            and your model is an encoder-decoder.
+        beta (`float`, *optional*, defaults to `0.1`):
+            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
+            the [paper](https://huggingface.co/papers/2310.12036).
+        label_smoothing (`float`, *optional*, defaults to `0.0`):
+            Label smoothing factor. This argument is required if you want to use the default data collator.
+        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
+            Type of loss to use. Possible values are:
+
+            - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+            - `"hinge"`: hinge loss on the normalized likelihood from the
+              [SLiC](https://huggingface.co/papers/2305.10425) paper.
+            - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+            - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
+            - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
+              automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.
+
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+        cpo_alpha (`float`, *optional*, defaults to `1.0`):
+            Weight of the BC regularizer in CPO training.
+        simpo_gamma (`float`, *optional*, defaults to `0.5`):
+            Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
+        alpha (`float`, *optional*, defaults to `0.0`):
+            Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
+            standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
+            / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
+            loss types.
+        label_pad_token_id (`int`, *optional*, defaults to `-100`):
+            Label pad token id. This argument is required if you want to use the default data collator.
+        padding_value (`int`, *optional*):
+            Padding value to use. If `None`, the padding value of the tokenizer is used.
+        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+            This argument is required if you want to use the default data collator.
+        generate_during_eval (`bool`, *optional*, defaults to `False`):
+            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
+        is_encoder_decoder (`bool`, *optional*):
+            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+            you need to specify if the model returned by the callable is an encoder-decoder model.
+        model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+            string.
+        dataset_num_proc (`int`, *optional*):
+            Number of processes to use for processing the dataset.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    unsloth_logit_chunk_multiplier : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Multiplier for chunked logit computations.'},
+    )
+    unsloth_grpo_mini_batch : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+    )
+    max_seq_length : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Maximum sequence length to truncate to.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        per_device_train_batch_size = 4,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        learning_rate = 5e-05,
+        lr_scheduler_type = 'linear',
+        lr_scheduler_kwargs = None,
+        warmup_steps = 0.1,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        optim_target_modules = None,
+        gradient_accumulation_steps = 2,
+        average_tokens_across_devices = True,
+        max_grad_norm = 1.0,
+        label_smoothing_factor = 0.0,
+        bf16 = False,
+        fp16 = False,
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = None,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        use_liger_kernel = False,
+        liger_kernel_config = None,
+        use_cache = False,
+        neftune_noise_alpha = None,
+        torch_empty_cache_steps = 250,
+        auto_find_batch_size = False,
+        logging_strategy = 'steps',
+        logging_steps = 1,
+        logging_first_step = False,
+        log_on_each_node = True,
+        logging_nan_inf_filter = False,
+        include_num_input_tokens_seen = False,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        disable_tqdm = None,
+        report_to = 'none',
+        run_name = None,
+        project = 'huggingface',
+        trackio_space_id = 'trackio',
+        eval_strategy = 'no',
+        eval_steps = None,
+        eval_delay = 0,
+        per_device_eval_batch_size = 4,
+        prediction_loss_only = False,
+        eval_on_start = False,
+        eval_do_concat_batches = True,
+        eval_use_gather_object = False,
+        eval_accumulation_steps = 2,
+        batch_eval_metrics = False,
+        save_only_model = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_on_each_node = False,
+        save_total_limit = None,
+        enable_jit_checkpoint = False,
+        push_to_hub = False,
+        hub_token = None,
+        hub_private_repo = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_always_push = False,
+        hub_revision = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        restore_callback_states_from_checkpoint = False,
+        full_determinism = False,
+        seed = 3407,
+        data_seed = 3407,
+        use_cpu = False,
+        accelerator_config = None,
+        parallelism_config = None,
+        dataloader_drop_last = False,
+        dataloader_num_workers = 0,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        dataloader_prefetch_factor = None,
+        remove_unused_columns = True,
+        label_names = None,
+        train_sampling_strategy = 'random',
505
+ length_column_name = 'length',
506
+ ddp_find_unused_parameters = None,
507
+ ddp_bucket_cap_mb = None,
508
+ ddp_broadcast_buffers = None,
509
+ ddp_backend = None,
510
+ ddp_timeout = 1800,
511
+ fsdp = None,
512
+ fsdp_config = None,
513
+ deepspeed = None,
514
+ debug = '',
515
+ skip_memory_metrics = True,
516
+ do_train = False,
517
+ do_eval = False,
518
+ do_predict = False,
519
+ resume_from_checkpoint = None,
520
+ warmup_ratio = None,
521
+ logging_dir = None,
522
+ local_rank = -1,
523
+ max_length = 1024,
524
+ max_prompt_length = 512,
525
+ max_completion_length = None,
526
+ beta = 0.1,
527
+ label_smoothing = 0.0,
528
+ loss_type = 'sigmoid',
529
+ disable_dropout = True,
530
+ cpo_alpha = 1.0,
531
+ simpo_gamma = 0.5,
532
+ alpha = 0.0,
533
+ label_pad_token_id = -100,
534
+ padding_value = None,
535
+ truncation_mode = 'keep_end',
536
+ generate_during_eval = False,
537
+ is_encoder_decoder = None,
538
+ model_init_kwargs = None,
539
+ dataset_num_proc = None,
540
+ vllm_sampling_params = None,
541
+ unsloth_num_chunks = -1,
542
+ unsloth_logit_chunk_multiplier = None,
543
+ unsloth_grpo_mini_batch = None,
544
+ max_seq_length = None,
545
+ **kwargs,
546
+ ):
547
+        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
549
+        if num_train_epochs is None:
+            num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        import multiprocessing as _mp
+        if _mp.get_start_method() != 'fork':
+            dataset_num_proc = None
+        elif dataset_num_proc is None:
+            import psutil
+            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+            memory_gb_left = psutil.virtual_memory().available / (1024**3)
+            if memory_gb_left <= 2: dataset_num_proc = 1
+            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+
+        super().__init__(
+            output_dir = output_dir,
+            per_device_train_batch_size = per_device_train_batch_size,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            learning_rate = learning_rate,
+            lr_scheduler_type = lr_scheduler_type,
+            lr_scheduler_kwargs = lr_scheduler_kwargs,
+            warmup_steps = warmup_steps,
+            optim = optim,
+            optim_args = optim_args,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            optim_target_modules = optim_target_modules,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            average_tokens_across_devices = average_tokens_across_devices,
+            max_grad_norm = max_grad_norm,
+            label_smoothing_factor = label_smoothing_factor,
+            bf16 = bf16,
+            fp16 = fp16,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            use_liger_kernel = use_liger_kernel,
+            liger_kernel_config = liger_kernel_config,
+            use_cache = use_cache,
+            neftune_noise_alpha = neftune_noise_alpha,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            auto_find_batch_size = auto_find_batch_size,
+            logging_strategy = logging_strategy,
+            logging_steps = logging_steps,
+            logging_first_step = logging_first_step,
+            log_on_each_node = log_on_each_node,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            disable_tqdm = disable_tqdm,
+            report_to = report_to,
+            run_name = run_name,
+            project = project,
+            trackio_space_id = trackio_space_id,
+            eval_strategy = eval_strategy,
+            eval_steps = eval_steps,
+            eval_delay = eval_delay,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            prediction_loss_only = prediction_loss_only,
+            eval_on_start = eval_on_start,
+            eval_do_concat_batches = eval_do_concat_batches,
+            eval_use_gather_object = eval_use_gather_object,
+            eval_accumulation_steps = eval_accumulation_steps,
+            batch_eval_metrics = batch_eval_metrics,
+            save_only_model = save_only_model,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_on_each_node = save_on_each_node,
+            save_total_limit = save_total_limit,
+            enable_jit_checkpoint = enable_jit_checkpoint,
+            push_to_hub = push_to_hub,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_always_push = hub_always_push,
+            hub_revision = hub_revision,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            full_determinism = full_determinism,
+            seed = seed,
+            data_seed = data_seed,
+            use_cpu = use_cpu,
+            accelerator_config = accelerator_config,
+            parallelism_config = parallelism_config,
+            dataloader_drop_last = dataloader_drop_last,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            train_sampling_strategy = train_sampling_strategy,
+            length_column_name = length_column_name,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            ddp_backend = ddp_backend,
+            ddp_timeout = ddp_timeout,
+            fsdp = fsdp,
+            fsdp_config = fsdp_config,
+            deepspeed = deepspeed,
+            debug = debug,
+            skip_memory_metrics = skip_memory_metrics,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            resume_from_checkpoint = resume_from_checkpoint,
+            warmup_ratio = warmup_ratio,
+            logging_dir = logging_dir,
+            local_rank = local_rank,
+            max_length = max_length,
+            max_prompt_length = max_prompt_length,
+            max_completion_length = max_completion_length,
+            beta = beta,
+            label_smoothing = label_smoothing,
+            loss_type = loss_type,
+            disable_dropout = disable_dropout,
+            cpo_alpha = cpo_alpha,
+            simpo_gamma = simpo_gamma,
+            alpha = alpha,
+            label_pad_token_id = label_pad_token_id,
+            padding_value = padding_value,
+            truncation_mode = truncation_mode,
+            generate_during_eval = generate_during_eval,
+            is_encoder_decoder = is_encoder_decoder,
+            model_init_kwargs = model_init_kwargs,
+            dataset_num_proc = dataset_num_proc,**kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        if unsloth_grpo_mini_batch is not None:
+            if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+            else:
+                raise ValueError(
+                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                    f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
+                )
+        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+        self.max_seq_length = max_seq_length
+
+pass
704
+
+class _UnslothCPOTrainer(BaseTrainer):
+    r""""""
+
+    _tag_names = ["trl", "cpo"]
+    _name = "CPO"
+    _paper = {
+        "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
+        "id": "2401.08417",
+        # docstyle-ignore
+        "citation": textwrap.dedent("""\
+            @inproceedings{xu2024contrastive,
+                title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
+                author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
+                year = 2024,
+                booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
+                publisher = {OpenReview.net},
+                url = {https://openreview.net/forum?id=51iwkioZpn}
+            }"""),
+    }
+
+    def __init__(
+        self,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+        args: Optional[CPOConfig] = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        peft_config: Optional[dict] = None,
+        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+    ):
+        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
+            warnings.warn(
+                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
+                "it and want it to remain, please share your comments here: "
+                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
+                "TRL_EXPERIMENTAL_SILENCE=1."
+            )
+        if args.model_init_kwargs is None:
+            model_init_kwargs = {}
+        elif not isinstance(model, str):
+            raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
+        else:
+            model_init_kwargs = args.model_init_kwargs
+            dtype = model_init_kwargs.get("dtype")
+            if dtype is not None:
+                # Convert to `torch.dtype` if an str is passed
+                if isinstance(dtype, str) and dtype != "auto":
+                    dtype = getattr(torch, dtype)
+                if dtype != "auto" and not isinstance(dtype, torch.dtype):
+                    raise ValueError(
+                        f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
+                    )
+                model_init_kwargs["dtype"] = dtype
+
+        if isinstance(model, str):
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+        # has been called in order to properly call autocast if needed.
+        self._peft_has_been_casted_to_bf16 = False
+
+        if not is_peft_available() and peft_config is not None:
+            raise ValueError(
+                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+            )
+        elif is_peft_available() and peft_config is not None:
+            # if model is a peft model and we have a peft_config, we merge and unload it first
+            if isinstance(model, PeftModel):
+                model = model.merge_and_unload()
+
+            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
+                )
+
+                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                if _support_gc_kwargs:
+                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+            elif args.gradient_checkpointing:
+                # For backward compatibility with older versions of transformers
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(module, input, output):
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+            # get peft model with the given config
+            model = model
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)
+                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+                self._peft_has_been_casted_to_bf16 = True
+
+        # For models that use gradient_checkpointing, we need to attach a hook that enables input
+        # to explicitly have `requires_grad=True`, otherwise training will either silently
+        # fail or completely fail.
+        elif args.gradient_checkpointing:
+            # For backward compatibility with older versions of transformers
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
+
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
+
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+            raise ValueError(
+                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                " Please install `wandb` or `comet-ml` to resolve."
+            )
+
+        if model is not None:
+            self.is_encoder_decoder = model.config.is_encoder_decoder
+        elif args.is_encoder_decoder is None:
+            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+        else:
+            self.is_encoder_decoder = args.is_encoder_decoder
+
+        if self.is_encoder_decoder:
+            self.decoder_start_token_id = model.config.decoder_start_token_id
+            self.pad_token_id = model.config.pad_token_id
+
+        if processing_class is None:
+            raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
+        if args.max_length is None:
+            logger.warning(
+                "`max_length` is not set in the CPOConfig's init"
+                " it will default to `512` by default, but you should do it yourself in the future.",
+            )
+            max_length = 512
+        else:
+            max_length = args.max_length
+        if args.max_prompt_length is None:
+            logger.warning(
+                "`max_prompt_length` is not set in the CPOConfig's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+            )
+            max_prompt_length = 128
+        else:
+            max_prompt_length = args.max_prompt_length
+
+        if not max_prompt_length < max_length:
+            raise ValueError(
+                f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})."
+            )
+
+        if args.max_completion_length is None and self.is_encoder_decoder:
+            logger.warning(
+                "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+            )
+            max_completion_length = 128
+        else:
+            max_completion_length = args.max_completion_length
+
+        if data_collator is None:
+            data_collator = DPODataCollatorWithPadding(
+                pad_token_id=processing_class.pad_token_id,
+                label_pad_token_id=args.label_pad_token_id,
+                is_encoder_decoder=self.is_encoder_decoder,
+            )
+
+            if args.remove_unused_columns:
+                args.remove_unused_columns = False
+                # warn users
+                logger.warning(
+                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
+                    " we have set it for you, but you should do it yourself in the future.",
+                )
+
+            self.use_dpo_data_collator = True
+        else:
+            self.use_dpo_data_collator = False
+
+        # Disable dropout in the model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
+        self.max_length = max_length
+        self.generate_during_eval = args.generate_during_eval
+        self.label_pad_token_id = args.label_pad_token_id
+        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+        self.max_prompt_length = max_prompt_length
+        self.truncation_mode = args.truncation_mode
+        self.max_completion_length = max_completion_length
+        self.processing_class = processing_class
+
+        if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
+            logger.warning(
+                f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
+                "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
+            )
+        if args.loss_type == "kto_pair":
+            raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
+
+        self.beta = args.beta
+        self.label_smoothing = args.label_smoothing
+        self.loss_type = args.loss_type
+        self.cpo_alpha = args.cpo_alpha
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+            logger.warning(
+                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                "loss.",
+            )
+
+        if args.loss_type == "simpo":
+            self.simpo_gamma = args.simpo_gamma
+
+        # AlphaPO parameter for reward shaping
+        self.alpha = args.alpha
+
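The hyperparameters stored above (`beta`, `label_smoothing`, `cpo_alpha`) drive the CPO preference loss. As a rough illustration only, not the trainer's actual batched tensor code, the default `loss_type='sigmoid'` loss on a single example can be sketched in plain Python; the log-probability and NLL values below are invented:

```python
import math

def cpo_sigmoid_loss(chosen_logps, rejected_logps, nll_loss,
                     beta=0.1, label_smoothing=0.0, cpo_alpha=1.0):
    """Scalar sketch of CPO's default 'sigmoid' loss for one example."""
    logits = chosen_logps - rejected_logps              # policy log-ratio
    def log_sigmoid(x):
        return -math.log1p(math.exp(-x))                # log(sigmoid(x))
    # Label-smoothed binary preference loss on the beta-scaled log-ratio
    pref = -log_sigmoid(beta * logits) * (1 - label_smoothing) \
           - log_sigmoid(-beta * logits) * label_smoothing
    # CPO adds an NLL term on the chosen response, weighted by cpo_alpha
    return pref + cpo_alpha * nll_loss

# Made-up numbers: chosen response is 3 nats more likely than rejected
loss = cpo_sigmoid_loss(chosen_logps=-12.0, rejected_logps=-15.0, nll_loss=1.2)
```

With `label_smoothing=0.0` the second term vanishes and the loss reduces to `-log(sigmoid(beta * (logp_chosen - logp_rejected))) + cpo_alpha * nll`, which is why a larger margin between chosen and rejected log-probs drives the preference term toward zero.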
937
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
+        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
+        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
+        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
+        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
+        # that the warning has already been issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        # Compute that only on the main process for faster data processing.
+        # see: https://github.com/huggingface/trl/pull/1255
+        with PartialState().main_process_first():
+            # Extract the prompt if needed, and apply the chat template if needed
+            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+            train_dataset = train_dataset.map(
+                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
+            )
+            if eval_dataset is not None:
+                eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+                eval_dataset = eval_dataset.map(
+                    maybe_apply_chat_template,
+                    fn_kwargs={"tokenizer": processing_class},
+                    num_proc=args.dataset_num_proc,
+                )
+
+            # tokenize the dataset
+            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+            if eval_dataset is not None:
+                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+        if not hasattr(self, "accelerator"):
+            raise AttributeError(
+                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+            )
+
+    def build_tokenized_answer(self, prompt, answer):
998
+        """
+        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`, but it does ensure `enc(a + b) = enc(a) +
+        enc(a + b)[len(enc(a)):]`. Reference:
+        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+        """
1003
+
+        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
+        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
+
+        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
+        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
+
+        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
+        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
+
+        # Prepare input tokens for token by token comparison
+        full_input_ids = np.array(full_tokenized["input_ids"])
+
+        if len(full_input_ids) != len(full_concat_input_ids):
+            raise ValueError("Prompt input ids and answer input ids should have the same length.")
+
+        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
+        # can be merged together when tokenizing prompt+answer. This could result
+        # on the last token from the prompt being different when tokenized on its own
+        # vs when done as prompt+answer.
+        response_token_ids_start_idx = len(prompt_input_ids)
+
+        # If tokenized prompt is different than both prompt+answer, then it means the
+        # last token has changed due to merging.
+        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
+            response_token_ids_start_idx -= 1
+
+        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
+        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
+
+        if len(prompt_input_ids) != len(prompt_attention_mask):
+            raise ValueError("Prompt input ids and attention mask should have the same length.")
+
+        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
+        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
+
+        return dict(
+            prompt_input_ids=prompt_input_ids,
+            prompt_attention_mask=prompt_attention_mask,
+            input_ids=answer_input_ids,
+            attention_mask=answer_attention_mask,
+        )
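The prefix-adjustment logic in `build_tokenized_answer` can be illustrated with hard-coded token ids (all ids below are made up): when the tokenizer merges the last prompt token with the first answer piece, the prompt-length prefix of `enc(prompt + answer)` no longer matches `enc(prompt)`, so the start index is walked back by one and the two slices still reconstruct the full encoding:

```python
# Hypothetical ids: prompt token 17 and the answer's first piece were
# merged into a single token 23 when prompt + answer was tokenized together.
prompt_input_ids = [5, 17]        # enc(prompt), tokenized alone
full_input_ids = [5, 23, 8]       # enc(prompt + answer)

response_token_ids_start_idx = len(prompt_input_ids)  # 2
# Same check as in the method: a prefix mismatch means a merge happened
if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
    response_token_ids_start_idx -= 1                 # back off to 1

prompt_part = full_input_ids[:response_token_ids_start_idx]   # [5]
answer_part = full_input_ids[response_token_ids_start_idx:]   # [23, 8]
# The two slices always reconstruct enc(prompt + answer)
assert prompt_part + answer_part == full_input_ids
```

This is why the method can only guarantee `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]` up to a one-token adjustment, rather than the naive concatenation identity.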
1046
+    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
+        """Tokenize a single row from a CPO specific dataset.
+
+        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
+        chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
+        we truncate the chosen/rejected.
+
+        We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
+        of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
+        """
+        batch = {}
+        prompt = feature["prompt"]
+        chosen = feature["chosen"]
+        rejected = feature["rejected"]
+
+        if not self.is_encoder_decoder:
+            # Check issues below for more details
+            #  1. https://github.com/huggingface/trl/issues/907
+            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+            #  3. https://github.com/LianjiaTech/BELLE/issues/337
+
+            if not isinstance(prompt, str):
+                raise ValueError(f"prompt should be an str but got {type(prompt)}")
+            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
+            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
+
+            if not isinstance(chosen, str):
+                raise ValueError(f"chosen should be an str but got {type(chosen)}")
+            chosen_tokens = self.build_tokenized_answer(prompt, chosen)
+
+            if not isinstance(rejected, str):
+                raise ValueError(f"rejected should be an str but got {type(rejected)}")
+            rejected_tokens = self.build_tokenized_answer(prompt, rejected)
+
+            # Last prompt token might get merged by tokenizer and
+            # it should not be included for generation if that happens
+            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
+
+            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
+            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
+            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
+
+            for k, v in prompt_tokens.items():
+                prompt_tokens[k] = v[:prompt_len_input_ids]
+
1091
+            # Make sure prompts only have one different token at most, and
+            # that their lengths differ by 1 at most
1093
+            num_diff_tokens = sum(
+                a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
+            )
+            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
+            if num_diff_tokens > 1 or num_diff_len > 1:
+                raise ValueError(
+                    "Chosen and rejected prompt_input_ids might only differ on the "
+                    "last token due to tokenizer merge ops."
+                )
+
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                self.processing_class.bos_token_id,
+                prompt_len_input_ids,
+                prompt_tokens,
+                chosen_prompt_len_input_ids,
+                chosen_tokens,
+                rejected_prompt_len_input_ids,
+                rejected_tokens,
+            )
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
+            )
+
+            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
+
+            # if combined sequence is too long, truncate the prompt
+            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
+                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                    if self.truncation_mode == "keep_start":
+                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
+                    elif self.truncation_mode == "keep_end":
+                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
+                    else:
+                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
+
+            # if that's still too long, truncate the response
+            for answer_tokens in [chosen_tokens, rejected_tokens]:
+                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                    for k in ["input_ids", "attention_mask"]:
+                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
+
+            # Create labels
+            chosen_sequence_tokens = {
+                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
+            }
+            rejected_sequence_tokens = {
+                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
+            }
+            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
+            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(chosen_tokens["prompt_input_ids"])
+            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
+            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(rejected_tokens["prompt_input_ids"])
+
+            for k, toks in {
+                "chosen_": chosen_sequence_tokens,
+                "rejected_": rejected_sequence_tokens,
+                "": prompt_tokens,
+            }.items():
+                for type_key, tokens in toks.items():
+                    if type_key == "token_type_ids":
+                        continue
+                    batch[f"{k}{type_key}"] = tokens
+
+        else:
+            chosen_tokens = self.processing_class(
+                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+            )
+            rejected_tokens = self.processing_class(
+                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+            )
+            prompt_tokens = self.processing_class(
+                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
+            )
+
+            batch["chosen_labels"] = chosen_tokens["input_ids"]
+            batch["rejected_labels"] = rejected_tokens["input_ids"]
+            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
+            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+
+            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
+                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                    labels=torch.tensor(batch["rejected_labels"])
+                )
+                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                    labels=torch.tensor(batch["chosen_labels"])
+                )
+
+        return batch
1190
+
+    @staticmethod
+    def concatenated_inputs(
+        batch: dict[str, Union[list, torch.LongTensor]],
+        is_encoder_decoder: bool = False,
+        label_pad_token_id: int = -100,
+        padding_value: int = 0,
+        device: Optional[torch.device] = None,
+    ) -> dict[str, torch.LongTensor]:
+        """Concatenate the chosen and rejected inputs into a single tensor.
+
+        Args:
+            batch:
+                A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
+                of shape (batch_size, sequence_length).
+            is_encoder_decoder:
+                Whether the model is an encoder-decoder model.
+            label_pad_token_id:
+                The label pad token id.
+            padding_value:
+                The padding value to use for the concatenated inputs_ids.
+            device:
+                The device for the concatenated inputs.
+
+        Returns:
+            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+        """
+        concatenated_batch = {}
+
+        if is_encoder_decoder:
+            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
+        else:
+            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
+
+        for k in batch:
+            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
+                if "labels" in k or is_encoder_decoder:
+                    pad_value = label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                concatenated_key = k.replace("chosen", "concatenated")
+                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+        for k in batch:
+            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
+                if "labels" in k or is_encoder_decoder:
+                    pad_value = label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                concatenated_key = k.replace("rejected", "concatenated")
+                concatenated_batch[concatenated_key] = torch.cat(
+                    (
+                        concatenated_batch[concatenated_key],
+                        pad_to_length(batch[k], max_length, pad_value=pad_value),
+                    ),
+                    dim=0,
+                ).to(device=device)
+
+        if is_encoder_decoder:
+            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
+            concatenated_batch["concatenated_attention_mask"] = (
+                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
+            )
+
+        return concatenated_batch
+
+    def cpo_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Compute the CPO loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            policy_chosen_logps:
+                Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps:
+                Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO
+            loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
+            the chosen and rejected responses, respectively.
+        """
+        # Apply AlphaPO reward transformation if alpha != 0
+        if self.alpha != 0.0:
+            # Compute probabilities
+            chosen_probs = torch.exp(policy_chosen_logps)
+            rejected_probs = torch.exp(policy_rejected_logps)
+
+            # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
+            policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
+            policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha
+
+            logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
+        else:
+            # Standard log probability rewards when alpha = 0
+            logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
+
+        # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
+        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
+        # calculates a conservative CPO loss.
+
+        if self.loss_type == "simpo":
+            gamma_logratios = self.simpo_gamma / self.beta
+            logits = logits - gamma_logratios
+            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
+            losses = (
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+            )
+        elif self.loss_type == "sigmoid":
+            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
+            losses = (
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+            )
+        elif self.loss_type == "hinge":
+            losses = torch.relu(1 - self.beta * logits)
+        elif self.loss_type == "ipo":
+            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
+            losses = (logits - 1 / (2 * self.beta)) ** 2
+        else:
+            raise ValueError(
+                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
+            )
+
+        # Calculate rewards for logging
+        if self.alpha != 0.0:
+            # When using AlphaPO transformation, use the transformed rewards
+            chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
+            rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
+        else:
+            # Standard log probability rewards
+            chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
+            rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
+
+        return losses, chosen_rewards, rejected_rewards
+
+    @staticmethod
+    def get_batch_logps(
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+        label_pad_token_id: int = -100,
+        is_encoder_decoder: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+            labels:
+                Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
+                ignored. Shape: (batch_size, sequence_length)
+            average_log_prob:
+                If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
+                log probabilities of the (non-masked) tokens.
+            label_pad_token_id: The label pad token id.
+            is_encoder_decoder: Whether the model is an encoder-decoder model.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
+            given logits.
+        """
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+        if not is_encoder_decoder:
+            labels = labels[:, 1:].clone()
+            logits = logits[:, :-1, :]
+        loss_mask = labels != label_pad_token_id
+
+        # dummy token; we'll ignore the losses on these tokens later
+        labels[labels == label_pad_token_id] = 0
+
+        per_token_logps = selective_log_softmax(logits, labels)
+
+        if average_log_prob:
+            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            return (per_token_logps * loss_mask).sum(-1)
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+        len_chosen = batch["chosen_labels"].shape[0]
+
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        outputs = model(
+            concatenated_batch["concatenated_input_ids"],
+            attention_mask=concatenated_batch["concatenated_attention_mask"],
+            use_cache=False,
+            **model_kwargs,
+        )
+        all_logits = outputs.logits
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+
+        if self.cpo_alpha == 0:
+            nll_loss = torch.tensor(0.0).to(self.accelerator.device)
+        else:
+            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+
+        all_logps = self.get_batch_logps(
+            all_logits,
+            concatenated_batch["concatenated_labels"],
+            average_log_prob=self.loss_type in ["ipo", "simpo"],
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+        )
+
+        chosen_logps = all_logps[:len_chosen]
+        rejected_logps = all_logps[len_chosen:]
+
+        chosen_logits = all_logits[:len_chosen]
+        rejected_logits = all_logits[len_chosen:]
+
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+
+        forward_output = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
+
+        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+        )
+
+        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item()
+        )
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
+
+        if self.aux_loss_enabled:
+            loss += self.aux_loss_coef * aux_loss
+
+        return loss, metrics
+
+    def compute_loss(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs=False,
+        num_items_in_batch=None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
+        compute_loss_context_manager = (
+            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+        )
+
+        with compute_loss_context_manager:
+            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
+
+        # force log the metrics
+        self.store_metrics(metrics, train_eval="train")
+
+        if return_outputs:
+            return (loss, metrics)
+        return loss
+
+    def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
+        """Generate samples from the model and reference model for the given batch of inputs."""
+
+        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
+        # the torch amp context manager as some hidden states are silently casted to full precision.
+        generate_context_manager = (
+            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+        )
+
+        with generate_context_manager:
+            policy_output = model.generate(
+                input_ids=batch["prompt_input_ids"],
+                attention_mask=batch["prompt_attention_mask"],
+                max_length=self.max_length,
+                do_sample=True,
+                pad_token_id=self.processing_class.pad_token_id,
+            )
+
+        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
+        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
+
+        return policy_output_decoded
+
+    def prediction_step(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+    ):
+        if ignore_keys is None:
+            if hasattr(model, "config"):
+                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        prediction_context_manager = (
+            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
+        )
+
+        with torch.no_grad(), prediction_context_manager:
+            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
+
+        # force log the metrics
+        self.store_metrics(metrics, train_eval="eval")
+
+        if prediction_loss_only:
+            return (loss.detach(), None, None)
+
+        # logits for the chosen and rejected samples from model
+        logits_dict = {
+            "eval_logits/chosen": metrics["eval_logits/chosen"],
+            "eval_logits/rejected": metrics["eval_logits/rejected"],
+        }
+        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
+        logits = torch.tensor(logits, device=self.accelerator.device)
+        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+        return (loss.detach(), logits, labels)
+
+    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
+        `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+
+        # Sample and save to game log if requested (for one batch to save time)
+        if self.generate_during_eval:
+            # Generate random indices within the range of the total number of samples
+            num_samples = len(dataloader.dataset)
+            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
+
+            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+            random_batch_dataset = dataloader.dataset.select(random_indices)
+            random_batch = self.data_collator(random_batch_dataset)
+            random_batch = self._prepare_inputs(random_batch)
+
+            policy_output_decoded = self.generate_from_model(self.model, random_batch)
+
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy"],
+                data=[
+                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
+                ],
+            )
+            if "wandb" in self.args.report_to:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
+
+        # Base evaluation
+        initial_output = super().evaluation_loop(
+            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
+        )
+
+        return initial_output
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`dict[str, float]`):
+                The values to log.
+            start_time (`float`, *optional*):
+                Start time of the training.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
+    def _shift_right(self, input_ids):
+        if self.decoder_start_token_id is None:
+            raise ValueError(
+                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
+            )
+
+        # shift inputs to the right
+        if is_torch_fx_proxy(input_ids):
+            # Item assignment is not supported natively for proxies.
+            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
+            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
+        else:
+            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+            shifted_input_ids[..., 0] = self.decoder_start_token_id
+
+        if self.pad_token_id is None:
+            raise ValueError("model.config.pad_token_id has to be defined.")
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
+
+        return shifted_input_ids
+
+    # Ensure the model card is saved along with the checkpoint
+    def _save_checkpoint(self, model, trial):
+        if self.args.hub_model_id is None:
+            model_name = Path(self.args.output_dir).name
+        else:
+            model_name = self.args.hub_model_id.split("/")[-1]
+        self.create_model_card(model_name=model_name)
+        super()._save_checkpoint(model, trial)
+class UnslothCPOTrainer(_UnslothCPOTrainer):
+    """
+
+    Initialize CPOTrainer.
+
+    Args:
+        model ([`~transformers.PreTrainedModel`]):
+            The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
+        args ([`CPOConfig`]):
+            The CPO config arguments to use for training.
+        data_collator ([`~transformers.DataCollator`]):
+            The data collator to use for training. If None is specified, the default data collator
+            ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
+            sequences in the batch, given a dataset of paired sequences.
+        train_dataset ([`~datasets.Dataset`]):
+            The dataset to use for training.
+        eval_dataset ([`~datasets.Dataset`]):
+            The dataset to use for evaluation.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+        model_init (`Callable[[], transformers.PreTrainedModel]`):
+            The model initializer to use for training. If None is specified, the default model initializer will be
+            used.
+        callbacks (`list[transformers.TrainerCallback]`):
+            The callbacks to use for training.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+            The optimizer and scheduler to use for training.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            The function to use to preprocess the logits before computing the metrics.
+        peft_config (`dict`, defaults to `None`):
+            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
+            a PEFT model.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
+            metric values.
+
+    """
+    def __init__(
+        self,
+        model = None,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        model_init = None,
+        callbacks = None,
+        preprocess_logits_for_metrics = None,
+        peft_config = None,
+        compute_metrics = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothCPOConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        if type(use_bf16) is not bool: use_bf16 = False
+        use_fp16 = getattr(args, 'fp16', False)
+        if type(use_fp16) is not bool: use_fp16 = False
+        force_float32 = False
+        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            # Forced float32 training
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            # Mixed precision training
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif mixed_precision_dtype == 'bfloat16':
+            # Both False since bfloat16 full finetuning doesn't do any autocasting.
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        if type(fp16_full_eval) is not bool: fp16_full_eval = False
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if type(bf16_full_eval) is not bool: bf16_full_eval = False
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if model is not None:
+            _warnings_issued = getattr(model, 'warnings_issued', None)
+            if _warnings_issued is None:
+                model.warnings_issued = {}
+            elif not isinstance(_warnings_issued, dict):
+                try:
+                    model.warnings_issued = dict(_warnings_issued)
+                except Exception:
+                    model.warnings_issued = {}
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+            elif args_max_seq_length is not None and model_max_seq_length is not None:
+                if args_max_seq_length > model_max_seq_length:
+                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+                    args.max_seq_length = model_max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = TransformersDataCollatorForLanguageModeling(
+                    __tokenizer,
+                    mlm = False,
+                    mlm_probability = 0.0,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(
+                    __tokenizer,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(
+                        __tokenizer.tokenizer,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+                else:
+                    data_collator = TransformersDataCollatorForLanguageModeling(
+                        __tokenizer.tokenizer,
+                        mlm = False,
+                        mlm_probability = 0.0,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('cpo_trainer', other_metrics)
+
+        # [TODO] Fix up DataParallel multiplying batch sizes
+        # [TODO] DDP works, but DP seems to not work? [TODO]
+        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+            if getattr(args, "_n_gpu", 1) != 1:
+                args._n_gpu = 1
+        if "model" in locals() and hasattr(model, "for_training"):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        super().__init__(
+            model = model,
+            args = args,
+            data_collator = data_collator,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            model_init = model_init,
+            callbacks = callbacks,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            peft_config = peft_config,
+            compute_metrics = compute_metrics,**kwargs)
+        if "model" in locals() and hasattr(model, "for_inference"):
+            model.for_inference()
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
1904
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1905
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1906
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1907
+ pass
1908
+ if hasattr(self, 'accelerator'):
1909
+ scaler = self.accelerator.scaler
1910
+ current_model = model
1911
+ while hasattr(current_model, 'model'):
1912
+ current_model.accelerator_scaler = scaler
1913
+ current_model = current_model.model
1914
+ current_model.accelerator_scaler = scaler
1915
+ pass
1916
+ if hasattr(self, 'train'):
1917
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1918
+ pass
1919
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1920
+ _vllm_tok = self.llm.get_tokenizer()
1921
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1922
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1923
+ _vllm_tok.chat_template = _pc.chat_template
1924
+ pass
1925
+
1926
+ pass
1927
+
1928
+
1929
+ if hasattr(logger, "addFilter"):
1930
+ import logging
1931
+ class HideLoggingMessage(logging.Filter):
1932
+ def __init__(self, text): self.text = text
1933
+ def filter(self, x): return not (self.text in x.getMessage())
1934
+ pass
1935
+ logger.addFilter(HideLoggingMessage("`use_cache=True`"))
1936
+
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,1287 @@
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, nn, os, prepare_deepspeed, torch, warnings)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
+try:
+    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+except:
+    def reset_unsloth_gradient_checkpointing_buffers(): pass
+def prepare_for_training_mode(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        # Enable training mode
+        _was_training = None
+        # Get gradient checkpointing setting from training arguments
+        use_gc = getattr(self.args, 'gradient_checkpointing', True)
+        if hasattr(self, 'model') and hasattr(self.model, "training"):
+            _was_training = self.model.training
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+            self.model.for_training(use_gradient_checkpointing=use_gc)
+        output = f(self, *args, **kwargs)
+        # Restore previous mode when possible
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+            if _was_training is False:
+                self.model.for_inference()
+            elif _was_training is True and hasattr(self.model, "for_training"):
+                self.model.for_training(use_gradient_checkpointing=use_gc)
+        # Reset gradient checkpointing buffers to free memory while staying ready for next run
+        try:
+            reset_unsloth_gradient_checkpointing_buffers()
+        except:
+            pass
+        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+        try:
+            import wandb
+            wandb.finish()
+        except:
+            pass
+        return output
+    return wrapper
+pass
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_hidden_states_selective_log_softmax(
+    hidden_states: torch.Tensor,
+    lm_head: torch.Tensor,
+    index: torch.Tensor,
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index  = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
+
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given a prompt tensor, returns the number of left-padding tokens in each sequence, so [pad, pad, pad, cat] = 3 tokens.
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+
+    padding_mask = (prompt_section == pad_token_id)
+
+    pad_token_counts = padding_mask.sum(dim=1)
+
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
+
+    where p are extra prompt tokens we got from slicing the torch tensor, c are completion tokens,
+    and pad are pad tokens, this function makes a completion mask that zeroes out the pad
+    and p tokens. So in this example: [0,0,0,1,1,1,0,0,0]
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    # Must do stable=True since the binary mask is unordered
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
+
+def align_logprobs_with_mask(
+    logprob_tensor: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pad_value: float = 0.0
+) -> torch.Tensor:
+    """
+    Aligns a log probability tensor with a given attention mask.
+    """
+
+    device = logprob_tensor.device
+    batch_size, logprob_seq_len = logprob_tensor.shape
+    mask_seq_len = attention_mask.shape[1]
+
+    padded_logprobs = torch.full(
+        attention_mask.shape,
+        fill_value=pad_value,
+        dtype=logprob_tensor.dtype,
+        device=device
+    )
+
+    left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+    cols = torch.arange(logprob_seq_len, device=device)
+    dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+    # Create destination row indices
+    # Shape: [batch_size, logprob_seq_len]
+    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+    # --- 4. Filter out-of-bounds indices and perform assignment ---
+    # Create a mask to identify only the indices that are within the bounds
+    # of the target tensor's sequence length.
+    valid_mask = dest_indices < mask_seq_len
+
+    # Use this mask to select only the valid row indices, column indices,
+    # and the corresponding values from the logprob tensor.
+    # This flattens the selected elements into 1D tensors.
+    valid_rows = row_indices[valid_mask]
+    valid_cols = dest_indices[valid_mask]
+    valid_vals = logprob_tensor[valid_mask]
+
+    # Place the valid values into their correct positions in the padded tensor
+    # using a single, efficient advanced indexing operation.
+    padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+    return padded_logprobs
+
+def autotune_batch_and_chunks(
+    total_input_rows,
+    seq_len,
+    hidden_size,
+    vocab_size,
+    dtype_bytes=16,
+    multiplier=None
+):
+    if multiplier is None:
+        final_m = max(4, seq_len // 4096)
+    else:
+        final_m = multiplier
+
+    if torch.cuda.is_available():
+        free_bytes, _ = torch.cuda.mem_get_info()
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        # For XPU: estimate free memory from total - reserved
+        total_mem = torch.xpu.get_device_properties(0).total_memory
+        reserved_mem = torch.xpu.memory_reserved()
+        free_bytes = total_mem - reserved_mem
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    else:
+        # Fallback: assume 8GB available
+        limit_gb = 8.0
+
+    bytes_to_gb = 1024**3
+
+    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+    logits_gb = base_logits / final_m
+
+    total_mem_gb = hidden_gb + logits_gb
+
+    valid_mask = total_mem_gb <= limit_gb
+    valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+    if valid_indices.shape[0] == 0:
+        # This means your GPU will OOM
+        return 4, final_m
+
+    best_idx = valid_indices[0].item()
+    final_b = int(b_vals[best_idx].item())
+
+    return final_b, final_m
+
+def sanitize_logprob(logprob):
+    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+    Filters NaN logprobs from vLLM outputs."""
+    value = logprob.logprob
+    if math.isnan(value):
+        logging.getLogger(__name__).warning(
+            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+        )
+        return None
+    return value
+@dataclass
+class UnslothGKDConfig(GKDConfig):
+    """
+
+    Configuration class for [`GKDTrainer`].
+
+    This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
+    please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
+
+    Args:
+        temperature (`float`, *optional*, defaults to `0.9`):
+            Temperature for sampling. The higher the temperature, the more random the completions.
+        lmbda (`float`, *optional*, defaults to `0.5`):
+            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
+            student-generated outputs).
+        beta (`float`, *optional*, defaults to `0.5`):
+            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
+            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
+        max_new_tokens (`int`, *optional*, defaults to `128`):
+            Maximum number of tokens to generate per completion.
+        teacher_model_name_or_path (`str`, *optional*):
+            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
+            trained.
+        teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
+            from a string.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+        seq_kd (`bool`, *optional*, defaults to `False`):
+            Controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
+            teacher-generated output).
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    unsloth_logit_chunk_multiplier : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Multiplier for chunked logit computations.'},
+    )
+    unsloth_grpo_mini_batch : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+    )
+    max_seq_length : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Maximum sequence length to truncate to.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        per_device_train_batch_size = 4,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        learning_rate = 5e-05,
+        lr_scheduler_type = 'linear',
+        lr_scheduler_kwargs = None,
+        warmup_steps = 0.1,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        optim_target_modules = None,
+        gradient_accumulation_steps = 2,
+        average_tokens_across_devices = True,
+        max_grad_norm = 1.0,
+        label_smoothing_factor = 0.0,
+        bf16 = False,
+        fp16 = False,
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = None,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        use_liger_kernel = False,
+        liger_kernel_config = None,
+        use_cache = False,
+        neftune_noise_alpha = None,
+        torch_empty_cache_steps = 250,
+        auto_find_batch_size = False,
+        logging_strategy = 'steps',
+        logging_steps = 1,
+        logging_first_step = False,
+        log_on_each_node = True,
+        logging_nan_inf_filter = False,
+        include_num_input_tokens_seen = False,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        disable_tqdm = None,
+        report_to = 'none',
+        run_name = None,
+        project = 'huggingface',
+        trackio_space_id = 'trackio',
+        eval_strategy = 'no',
+        eval_steps = None,
+        eval_delay = 0,
+        per_device_eval_batch_size = 4,
+        prediction_loss_only = False,
+        eval_on_start = False,
+        eval_do_concat_batches = True,
+        eval_use_gather_object = False,
+        eval_accumulation_steps = 2,
+        batch_eval_metrics = False,
+        save_only_model = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_on_each_node = False,
+        save_total_limit = None,
+        enable_jit_checkpoint = False,
+        push_to_hub = False,
+        hub_token = None,
+        hub_private_repo = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_always_push = False,
+        hub_revision = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        restore_callback_states_from_checkpoint = False,
+        full_determinism = False,
+        seed = 3407,
+        data_seed = 3407,
+        use_cpu = False,
+        accelerator_config = None,
+        parallelism_config = None,
+        dataloader_drop_last = False,
+        dataloader_num_workers = 0,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        dataloader_prefetch_factor = None,
+        remove_unused_columns = True,
+        label_names = None,
+        train_sampling_strategy = 'random',
+        length_column_name = 'length',
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        ddp_backend = None,
+        ddp_timeout = 1800,
+        fsdp = None,
+        fsdp_config = None,
+        deepspeed = None,
+        debug = '',
+        skip_memory_metrics = True,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        resume_from_checkpoint = None,
+        warmup_ratio = None,
+        logging_dir = None,
+        local_rank = -1,
+        model_init_kwargs = None,
+        chat_template_path = None,
+        dataset_text_field = 'text',
+        dataset_kwargs = None,
+        dataset_num_proc = None,
+        eos_token = None,
+        pad_token = None,
+        max_length = 1024,
+        packing = False,
+        packing_strategy = 'bfd',
+        padding_free = None,
+        pad_to_multiple_of = None,
+        eval_packing = None,
+        completion_only_loss = None,
+        assistant_only_loss = False,
+        loss_type = 'nll',
+        activation_offloading = False,
+        temperature = 0.9,
+        lmbda = 0.5,
+        beta = 0.5,
+        max_new_tokens = 128,
+        teacher_model_name_or_path = None,
+        teacher_model_init_kwargs = None,
+        disable_dropout = True,
+        seq_kd = False,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        unsloth_logit_chunk_multiplier = None,
+        unsloth_grpo_mini_batch = None,
+        max_seq_length = None,
+        **kwargs,
+    ):
+        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+        if num_train_epochs is None:
+            num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        import multiprocessing as _mp
+        if _mp.get_start_method() != 'fork':
+            dataset_num_proc = None
+        elif dataset_num_proc is None:
+            import psutil
+            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+            memory_gb_left = psutil.virtual_memory().available / (1024**3)
+            if memory_gb_left <= 2: dataset_num_proc = 1
+            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+        if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
+            from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
+            if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
+                from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
+                pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
+
+        if temperature <= 0:
+            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
+        elif temperature >= 10:
+            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
+
+
+        super().__init__(
+            output_dir = output_dir,
+            per_device_train_batch_size = per_device_train_batch_size,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            learning_rate = learning_rate,
+            lr_scheduler_type = lr_scheduler_type,
+            lr_scheduler_kwargs = lr_scheduler_kwargs,
+            warmup_steps = warmup_steps,
+            optim = optim,
+            optim_args = optim_args,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            optim_target_modules = optim_target_modules,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            average_tokens_across_devices = average_tokens_across_devices,
+            max_grad_norm = max_grad_norm,
+            label_smoothing_factor = label_smoothing_factor,
+            bf16 = bf16,
+            fp16 = fp16,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            use_liger_kernel = use_liger_kernel,
+            liger_kernel_config = liger_kernel_config,
+            use_cache = use_cache,
+            neftune_noise_alpha = neftune_noise_alpha,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            auto_find_batch_size = auto_find_batch_size,
+            logging_strategy = logging_strategy,
+            logging_steps = logging_steps,
+            logging_first_step = logging_first_step,
+            log_on_each_node = log_on_each_node,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            disable_tqdm = disable_tqdm,
+            report_to = report_to,
+            run_name = run_name,
+            project = project,
+            trackio_space_id = trackio_space_id,
+            eval_strategy = eval_strategy,
+            eval_steps = eval_steps,
+            eval_delay = eval_delay,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            prediction_loss_only = prediction_loss_only,
+            eval_on_start = eval_on_start,
+            eval_do_concat_batches = eval_do_concat_batches,
+            eval_use_gather_object = eval_use_gather_object,
+            eval_accumulation_steps = eval_accumulation_steps,
+            batch_eval_metrics = batch_eval_metrics,
+            save_only_model = save_only_model,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_on_each_node = save_on_each_node,
+            save_total_limit = save_total_limit,
+            enable_jit_checkpoint = enable_jit_checkpoint,
+            push_to_hub = push_to_hub,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_always_push = hub_always_push,
+            hub_revision = hub_revision,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            full_determinism = full_determinism,
+            seed = seed,
+            data_seed = data_seed,
+            use_cpu = use_cpu,
+            accelerator_config = accelerator_config,
+            parallelism_config = parallelism_config,
+            dataloader_drop_last = dataloader_drop_last,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            train_sampling_strategy = train_sampling_strategy,
+            length_column_name = length_column_name,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            ddp_backend = ddp_backend,
+            ddp_timeout = ddp_timeout,
+            fsdp = fsdp,
+            fsdp_config = fsdp_config,
+            deepspeed = deepspeed,
+            debug = debug,
+            skip_memory_metrics = skip_memory_metrics,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            resume_from_checkpoint = resume_from_checkpoint,
+            warmup_ratio = warmup_ratio,
+            logging_dir = logging_dir,
+            local_rank = local_rank,
+            model_init_kwargs = model_init_kwargs,
+            chat_template_path = chat_template_path,
+            dataset_text_field = dataset_text_field,
+            dataset_kwargs = dataset_kwargs,
+            dataset_num_proc = dataset_num_proc,
+            eos_token = eos_token,
+            pad_token = pad_token,
+            max_length = max_length,
+            packing = packing,
+            packing_strategy = packing_strategy,
+            padding_free = padding_free,
+            pad_to_multiple_of = pad_to_multiple_of,
+            eval_packing = eval_packing,
+            completion_only_loss = completion_only_loss,
+            assistant_only_loss = assistant_only_loss,
+            loss_type = loss_type,
+            activation_offloading = activation_offloading,
+            temperature = temperature,
+            lmbda = lmbda,
+            beta = beta,
+            max_new_tokens = max_new_tokens,
+            teacher_model_name_or_path = teacher_model_name_or_path,
+            teacher_model_init_kwargs = teacher_model_init_kwargs,
+            disable_dropout = disable_dropout,
+            seq_kd = seq_kd,**kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        if unsloth_grpo_mini_batch is not None:
+            if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+            else:
+                raise ValueError(
+                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
688
+ f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
689
+ )
690
+ self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
691
+ self.max_seq_length = max_seq_length
692
+
693
+ pass
694
+
695
+ class _UnslothGKDTrainer(SFTTrainer):
+ """"""
+
+ _tag_names = ["trl", "gkd"]
+ _name = "GKD"
+ _paper = {
+ "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
+ "id": "2306.13649",
+ # docstyle-ignore
+ "citation": textwrap.dedent("""\
+ @inproceedings{agarwal2024on-policy,
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
+ year = 2024,
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
+ publisher = {OpenReview.net},
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
+ }"""),
+ }
+
+ def __init__(
+ self,
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
+ args: Optional[GKDConfig] = None,
+ data_collator: Optional[DataCollator] = None, # type: ignore
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ peft_config: Optional["PeftConfig"] = None,
+ formatting_func: Optional[Callable] = None,
+ ):
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
+ warnings.warn(
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
+ "it and want it to remain, please share your comments here: "
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
+ "TRL_EXPERIMENTAL_SILENCE=1."
+ )
+ # Ensure Trainer does not drop non-signature columns used by the collator [e.g., "prompts"]
+ args.remove_unused_columns = False
+ # Respect a user-provided data_collator; otherwise, provide a ChatML collator
+ if data_collator is None:
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
+
+ # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator,
+ # so that raw conversational fields [e.g., "messages"] remain available to the collator.
+ if args.dataset_kwargs is None:
+ args.dataset_kwargs = {"skip_prepare_dataset": True}
+ else:
+ args.dataset_kwargs["skip_prepare_dataset"] = True
+
+ # Liger fused GKD loss [JSD]
+ self.use_liger_gkd_loss = False
+ if args.use_liger_kernel:
+ self.liger_jsd_loss = LigerFusedLinearJSDLoss(
+ beta=args.beta,
+ ignore_index=-100,
+ temperature=args.temperature,
+ compiled=False,
+ )
+ self.use_liger_gkd_loss = True
+
+ super().__init__(
+ model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ peft_config=peft_config,
+ formatting_func=formatting_func,
+ )
+
+ if args.teacher_model_init_kwargs is None:
+ teacher_model_init_kwargs = {}
+ elif not isinstance(teacher_model, str):
+ raise ValueError(
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
+ )
+ else:
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
+ teacher_model_init_kwargs["dtype"] = (
+ teacher_model_init_kwargs["dtype"]
+ if teacher_model_init_kwargs["dtype"] in ["auto", None]
+ else getattr(torch, teacher_model_init_kwargs["dtype"])
+ )
+
+ if isinstance(teacher_model, str):
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
+
+ # Disable dropout in the model
+ if args.disable_dropout:
+ disable_dropout_in_model(self.model)
+
+ if self.is_deepspeed_enabled:
+ self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
+ else:
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
+
+ self.lmbda = args.lmbda
+ self.beta = args.beta
+ self.temperature = args.temperature
+ self.seq_kd = args.seq_kd
+
+ self.generation_config = GenerationConfig(
+ max_new_tokens=args.max_new_tokens,
+ temperature=args.temperature,
+ do_sample=True,
+ top_k=0,
+ use_cache=False if args.gradient_checkpointing else True,
+ pad_token_id=self.processing_class.pad_token_id,
+ )
+ # Set custom EOS tokens if they are specified by the model's generation
+ # config. This is important for models with the Llama 3 chat template,
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
+ # turns or messages.
+ if (
+ hasattr(self.model.generation_config, "eos_token_id")
+ and self.model.generation_config.eos_token_id is not None
+ ):
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
+
+ @staticmethod
+ def generalized_jsd_loss(
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
+ ):
+ """
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
+ of https://huggingface.co/papers/2306.13649 for the definition.
+
+ Args:
+ student_logits:
+ Tensor of shape (batch_size, sequence_length, vocab_size)
+ teacher_logits:
+ Tensor of shape (batch_size, sequence_length, vocab_size)
+ labels:
+ Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
+ loss
+ beta:
+ Interpolation coefficient between 0 and 1 (default: 0.5)
+ temperature:
+ Softmax temperature (default: 1.0)
+ reduction:
+ Specifies the reduction to apply to the output (default: 'batchmean')
+
+ Returns:
+ loss: Scalar tensor with the generalized JSD loss
+ """
+
+ # Apply temperature scaling
+ student_logits = student_logits / temperature
+ teacher_logits = teacher_logits / temperature
+
+ # Compute log probabilities for student and probabilities for teacher
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+ if beta == 0:
+ jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
+ elif beta == 1:
+ jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
+ else:
+ # Compute the log of the mixture distribution
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
+ mixture_log_probs = torch.logsumexp(
+ torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
+ dim=0,
+ )
+
+ # Compute KL divergences using F.kl_div
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
+
+ # Compute the Generalized Jensen-Shannon Divergence
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
+
+ # Masking
+ if labels is not None:
+ mask = labels != -100
+ jsd = jsd[mask]
+
+ # Apply reduction
+ if reduction == "batchmean":
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
+ elif reduction == "sum":
+ return jsd.sum()
+ elif reduction == "mean":
+ return jsd.mean()
+ else:
+ return jsd
+
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ if self.use_liger_gkd_loss:
+ # Forward only through the base models (avoid lm_head to save memory)
+ unwrapped_student = self.accelerator.unwrap_model(model)
+ if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
+ base_student = unwrapped_student.get_decoder()
+ else:
+ base_student = getattr(
+ unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
+ )
+
+ student_outputs = base_student(
+ input_ids=inputs["input_ids"],
+ attention_mask=inputs["attention_mask"],
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ self.teacher_model.eval()
+ unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
+ if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None:
+ base_teacher = unwrapped_teacher.get_decoder()
+ else:
+ base_teacher = getattr(
+ unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher
+ )
+ with torch.no_grad():
+ teacher_outputs = base_teacher(
+ input_ids=inputs["input_ids"],
+ attention_mask=inputs["attention_mask"],
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ # hidden states (shifted)
+ student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
+ teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
+
+ # labels mask and labels (shifted)
+ labels_mask = inputs["labels"] != -100
+ masked_input_ids = torch.where(
+ labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)
+ )
+ true_labels = masked_input_ids[:, 1:].contiguous()
+
+ # heads
+ student_head = unwrapped_student.get_output_embeddings()
+ teacher_head = unwrapped_teacher.get_output_embeddings()
+
+ # liger fused jsd loss
+ loss = self.liger_jsd_loss(
+ student_input=student_hidden,
+ student_weight=student_head.weight,
+ teacher_input=teacher_hidden,
+ teacher_weight=teacher_head.weight,
+ true_labels=true_labels,
+ student_bias=getattr(student_head, "bias", None),
+ teacher_bias=getattr(teacher_head, "bias", None),
+ )
+ else:
+ # compute student output
+ student_outputs = model(
+ input_ids=inputs["input_ids"],
+ attention_mask=inputs["attention_mask"],
+ )
+
+ # compute teacher output in eval mode
+ self.teacher_model.eval()
+ with torch.no_grad():
+ teacher_outputs = self.teacher_model(
+ input_ids=inputs["input_ids"],
+ attention_mask=inputs["attention_mask"],
+ )
+
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
+ prompt_lengths = inputs["prompts"].shape[1]
+ shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :]
+ shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :]
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
+
+ # compute loss
+ loss = self.generalized_jsd_loss(
+ student_logits=shifted_student_logits,
+ teacher_logits=shifted_teacher_logits,
+ labels=shifted_labels,
+ beta=self.beta,
+ )
+
+ # empty cache
+ empty_cache()
+
+ # Return loss
+ return (loss, student_outputs) if return_outputs else loss
+
+ @staticmethod
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
+ # Generate output with respect to the prompt-only
+ generated_outputs = model.generate(
+ input_ids=inputs["prompts"],
+ attention_mask=inputs.get("prompt_attention_mask", None),
+ generation_config=generation_config,
+ return_dict_in_generate=True,
+ )
+
+ # Get the generated token IDs
+ generated_tokens = generated_outputs.sequences
+ # Calculate new attention mask
+ new_attention_mask = torch.ones_like(generated_tokens)
+ new_labels = generated_tokens.clone()
+
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
+ if pad_token_id is not None:
+ new_labels[new_labels == pad_token_id] = -100
+ new_attention_mask[generated_tokens == pad_token_id] = 0
+
+ return generated_tokens, new_attention_mask, new_labels
+
+ def training_step(
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
+ """
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
+
+ This method implements the on-policy learning approach described in the GKD paper. With probability
+ `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
+ the original inputs.
+ """
+ if self.seq_kd:
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
+ )
+ inputs["input_ids"] = new_input_ids
+ inputs["attention_mask"] = new_attention_mask
+ inputs["labels"] = new_labels
+ if random.random() <= self.lmbda:
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
+ )
+ inputs["input_ids"] = new_input_ids
+ inputs["attention_mask"] = new_attention_mask
+ inputs["labels"] = new_labels
+
+ loss = super().training_step(model, inputs, num_items_in_batch)
+ return loss
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
+ """
+ Trainer for Generalized Knowledge Distillation (GKD) of language models.
+
+ For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated
+ Mistakes](https://huggingface.co/papers/2306.13649).
+
+ Args:
+ model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
+ Model to be trained, or the string identifier of the model to be instantiated from a pretrained model.
+ teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
+ Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a
+ pretrained model.
+ args ([`GKDConfig`], *optional*):
+ Training arguments.
+ data_collator ([`~transformers.DataCollator`], *optional*):
+ Data collator to batch samples from the dataset. It defaults to a [`DataCollatorForChatML`] using the
+ `processing_class`.
+ train_dataset ([`~datasets.Dataset`], *optional*):
+ Dataset for training.
+ eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
+ Dataset for evaluation.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
+ Class to process the data.
+ compute_metrics (`Callable`, *optional*):
+ Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
+ dictionary string to float.
+ callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
+ Callbacks to use during training.
+ optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
+ Tuple containing the optimizer and the learning rate scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable`, *optional*):
+ Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
+ return the logits to be used for metrics computation.
+ peft_config ([`~peft.PeftConfig`], *optional*):
+ PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
+ wrapped with the specified PEFT adapter.
+ formatting_func (`Callable`, *optional*):
+ Function to format the dataset. Must take in an example and return an example.
+
+ """
+ def __init__(
+ self,
+ model = None,
+ teacher_model = None,
+ args = None,
+ data_collator = None,
+ train_dataset = None,
+ eval_dataset = None,
+ processing_class = None,
+ compute_metrics = None,
+ callbacks = None,
+ preprocess_logits_for_metrics = None,
+ peft_config = None,
+ formatting_func = None,
+ **kwargs
+ ):
+ if args is None: args = UnslothGKDConfig()
+ use_bf16 = getattr(args, 'bf16', False)
+ if type(use_bf16) is not bool: use_bf16 = False
+ use_fp16 = getattr(args, 'fp16', False)
+ if type(use_fp16) is not bool: use_fp16 = False
+ force_float32 = False
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
+ force_float32 = True
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+ from unsloth_zoo.utils import _get_dtype
+ dtype = _get_dtype(dtype)
+ float16 = dtype == torch.float16
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+ if force_float32:
+ # Forced float32 training
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+ # args.mixed_precision is a new argument which needs to be set now
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+ # Mixed precision training
+ args.fp16 = float16
+ args.bf16 = not float16
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+ # args.mixed_precision is a new argument which needs to be set now
+ elif mixed_precision_dtype == 'bfloat16':
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+ # args.mixed_precision is a new argument which needs to be set now
+
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+ args.eval_strategy = 'steps'
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+ if ga_steps is not None and ga_steps > 1:
+ from transformers import __version__ as transformers_version
+ if Version(transformers_version) <= Version('4.45.2'):
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+ if getattr(args, 'eval_strategy', 'no') != 'no':
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+ if force_float32:
+ args.bf16_full_eval = False
+ args.fp16_full_eval = False
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+ args.bf16_full_eval = True
+ args.fp16_full_eval = False
+ elif not bf16_full_eval and not fp16_full_eval:
+ args.bf16_full_eval = args.bf16
+ args.fp16_full_eval = args.fp16
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if model is not None:
+ _warnings_issued = getattr(model, 'warnings_issued', None)
+ if _warnings_issued is None:
+ model.warnings_issued = {}
+ elif not isinstance(_warnings_issued, dict):
+ try:
+ model.warnings_issued = dict(_warnings_issued)
+ except Exception:
+ model.warnings_issued = {}
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ elif args_max_seq_length is not None and model_max_seq_length is not None:
+ if args_max_seq_length > model_max_seq_length:
+ print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+ 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+ args.max_seq_length = model_max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+ data_collator = TransformersDataCollatorForLanguageModeling(
+ __tokenizer,
+ mlm = False,
+ mlm_probability = 0.0,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+ data_collator = DataCollatorForSeq2Seq(
+ __tokenizer,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ else:
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
+ data_collator = DataCollatorForSeq2Seq(
+ __tokenizer.tokenizer,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ else:
+ data_collator = TransformersDataCollatorForLanguageModeling(
+ __tokenizer.tokenizer,
+ mlm = False,
+ mlm_probability = 0.0,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ other_metrics = []
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('gkd_trainer', other_metrics)
+
+ # [TODO] Fix up DataParallel multiplying batch sizes
+ # [TODO] DDP works, but DP seems to not work? [TODO]
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+ if getattr(args, "_n_gpu", 1) != 1:
+ args._n_gpu = 1
+ if "model" in locals() and hasattr(model, "for_training"):
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+ super().__init__(
+ model = model,
+ teacher_model = teacher_model,
+ args = args,
+ data_collator = data_collator,
+ train_dataset = train_dataset,
+ eval_dataset = eval_dataset,
+ processing_class = processing_class,
+ compute_metrics = compute_metrics,
+ callbacks = callbacks,
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+ peft_config = peft_config,
+ formatting_func = formatting_func,**kwargs)
+ if "model" in locals() and hasattr(model, "for_inference"):
+ model.for_inference()
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+ if hasattr(self, 'accelerator'):
+ scaler = self.accelerator.scaler
+ current_model = model
+ while hasattr(current_model, 'model'):
+ current_model.accelerator_scaler = scaler
+ current_model = current_model.model
+ current_model.accelerator_scaler = scaler
+ pass
+ if hasattr(self, 'train'):
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
+ pass
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
+ _vllm_tok = self.llm.get_tokenizer()
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
+ _vllm_tok.chat_template = _pc.chat_template
+ pass
+
+ pass
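The `generalized_jsd_loss` above interpolates between forward and reverse KL via `beta`, using a log-space mixture built with `logsumexp`. As a standalone sanity check of that quantity, the same divergence for a single pair of categorical distributions can be sketched in pure Python (the function name and inputs here are illustrative, not part of TRL or Unsloth; no temperature scaling or label masking):

```python
import math

def generalized_jsd(student_p, teacher_p, beta=0.5):
    """Generalized Jensen-Shannon divergence between two probability
    vectors, mirroring the mixture in generalized_jsd_loss above."""
    # Mixture m = (1 - beta) * student + beta * teacher, matching the
    # log-space mixture the trainer builds with torch.logsumexp.
    m = [(1 - beta) * s + beta * t for s, t in zip(student_p, teacher_p)]

    def kl(p, q):
        # KL(p || q), skipping zero-probability terms
        return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

    # JSD_beta = beta * KL(teacher || m) + (1 - beta) * KL(student || m)
    return beta * kl(teacher_p, m) + (1 - beta) * kl(student_p, m)
```

Note that this mixture form collapses to zero at `beta = 0` (m equals the student) and `beta = 1` (m equals the teacher), which is why the trainer special-cases those two values to plain KL divergences instead.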
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1860 @@
+ """
+ 2026.3.2
+ 2026.3.4
+ 5.3.0
+ 0.24.0
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth auto generated code
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from unsloth_zoo.temporary_patches.common import torch_compile
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
+
+
+ import os
+ import math
+ import logging
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ import inspect
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+ from transformers.training_args import ParallelMode
+ from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+ # Wrap trainer with padding to right and enable training mode
+ # Also patches W&B since multiple runs must use wandb.finish()
+ import functools
+ from types import MethodType
+ try:
+ from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+ except:
+ def reset_unsloth_gradient_checkpointing_buffers(): pass
+ def prepare_for_training_mode(f):
+ @functools.wraps(f)
+ def wrapper(self, *args, **kwargs):
+ # Enable training mode
+ _was_training = None
+ # Get gradient checkpointing setting from training arguments
+ use_gc = getattr(self.args, 'gradient_checkpointing', True)
+ if hasattr(self, 'model') and hasattr(self.model, "training"):
+ _was_training = self.model.training
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+ self.model.for_training(use_gradient_checkpointing=use_gc)
+ output = f(self, *args, **kwargs)
+ # Restore previous mode when possible
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+ if _was_training is False:
+ self.model.for_inference()
+ elif _was_training is True and hasattr(self.model, "for_training"):
+ self.model.for_training(use_gradient_checkpointing=use_gc)
+ # Reset gradient checkpointing buffers to free memory while staying ready for next run
+ try:
+ reset_unsloth_gradient_checkpointing_buffers()
+ except:
+ pass
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+ try:
+ import wandb
+ wandb.finish()
+ except:
+ pass
+ return output
+ return wrapper
+ pass
+
+ torch_compile_options = {
+ "epilogue_fusion" : True,
+ "max_autotune" : False,
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_hidden_states_selective_log_softmax(
+ hidden_states: torch.Tensor,
+ lm_head: torch.Tensor,
+ index: torch.Tensor,
+ chunks: int = 4,
+ logit_scale_multiply: float = 0.0,
+ logit_scale_divide: float = 0.0,
+ logit_softcapping: float = 0.0,
+ temperature: float = 1.0,
+ ) -> torch.Tensor:
+ # All Unsloth Zoo code licensed under AGPL3
+ flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+ flat_index = index.reshape(-1)
+
+ chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+ chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+ all_per_token_logps = []
+
+ for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+ chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+ if logit_scale_multiply != 0.0:
+ chunk_logits = chunk_logits * logit_scale_multiply
+ if logit_scale_divide != 0.0:
+ chunk_logits = chunk_logits / logit_scale_divide
+ if logit_softcapping != 0.0:
+ chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+ chunk_logits = chunk_logits.to(torch.float32)
+
+ if temperature != 1.0:
+ chunk_logits = chunk_logits / temperature
+
+ selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+ logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+ per_token_logps = selected_logits - logsumexp_values
+ all_per_token_logps.append(per_token_logps)
+
+ all_per_token_logps = torch.concat(all_per_token_logps)
+
+ all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+ return all_per_token_logps
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_selective_log_softmax(logits, index):
+ # Split into 4 chunks only
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+ all_per_token_logps = []
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+ chunk_logits = chunk_logits.to(torch.float32)
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+ per_token_logps = selected_logits - logsumexp_values
+ all_per_token_logps.append(per_token_logps)
+ pass
+ all_per_token_logps = torch.concat(all_per_token_logps)
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+ return all_per_token_logps
+
+ def calculate_pad_tokens_in_prompt(
+ input_ids: torch.Tensor,
+ logits_to_keep: int,
+ pad_token_id: int
+ ) -> torch.Tensor:
+ """
+ Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
+ """
+ if logits_to_keep >= input_ids.shape[1]:
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+ prompt_section = input_ids[:, :-logits_to_keep]
+
+ padding_mask = (prompt_section == pad_token_id)
+
+ pad_token_counts = padding_mask.sum(dim=1)
+
+ return pad_token_counts
+
+ def create_completion_attention_mask(
+ completion_input_ids: torch.Tensor,
+ left_pad_tokens_per_prompt: torch.Tensor,
+ max_left_pad: int,
+ pad_token_id: int
+ ) -> torch.Tensor:
+ """
+ Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
+
+ Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
+ and pad are pad tokens, this function would make a completion mask that would 0 out the pad
+ and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
+ """
+ batch_size, completion_len = completion_input_ids.shape
+ device = completion_input_ids.device
+
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+ non_padding_mask = (completion_input_ids != pad_token_id)
+
+ final_mask = shift_mask & non_padding_mask
+
+ return final_mask
+
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+ """
+ Moves all padding tokens in each sequence of a batch to the right.
+ """
+ mask = (tensor != pad_id)
+ # Must do stable=True since binary mark is unordered
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
+ return packed_tensor
+
+ def align_logprobs_with_mask(
+ logprob_tensor: torch.Tensor,
+ attention_mask: torch.Tensor,
+ pad_value: float = 0.0
+ ) -> torch.Tensor:
+ """
+ Aligns a log probability tensor with a given attention mask.
+ """
+
+ device = logprob_tensor.device
+ batch_size, logprob_seq_len = logprob_tensor.shape
+ mask_seq_len = attention_mask.shape[1]
+
+ padded_logprobs = torch.full(
+ attention_mask.shape,
+ fill_value=pad_value,
+ dtype=logprob_tensor.dtype,
+ device=device
+ )
+
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+ cols = torch.arange(logprob_seq_len, device=device)
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+ # Create destination row indices
+ # Shape: [batch_size, logprob_seq_len]
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
+ # Create a mask to identify only the indices that are within the bounds
+ # of the target tensor's sequence length.
+ valid_mask = dest_indices < mask_seq_len
+
+ # Use this mask to select only the valid row indices, column indices,
+ # and the corresponding values from the logprob tensor.
+ # This flattens the selected elements into 1D tensors.
+ valid_rows = row_indices[valid_mask]
+ valid_cols = dest_indices[valid_mask]
+ valid_vals = logprob_tensor[valid_mask]
+
+ # Place the valid values into their correct positions in the padded tensor
+ # using a single, efficient advanced indexing operation.
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+ return padded_logprobs
+
+ def autotune_batch_and_chunks(
+ total_input_rows,
+ seq_len,
+ hidden_size,
+ vocab_size,
+ dtype_bytes=16,
+ multiplier=None
+ ):
+ if multiplier is None:
+ final_m = max(4, seq_len // 4096)
+ else:
+ final_m = multiplier
+
+ if torch.cuda.is_available():
+ free_bytes, _ = torch.cuda.mem_get_info()
+ limit_gb = (free_bytes / (1024**3))*.80
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
+ # For XPU: estimate free memory from total - reserved
+ total_mem = torch.xpu.get_device_properties(0).total_memory
+ reserved_mem = torch.xpu.memory_reserved()
+ free_bytes = total_mem - reserved_mem
+ limit_gb = (free_bytes / (1024**3)) * 0.80
+ else:
+ # Fallback: assume 8GB available
+ limit_gb = 8.0
+
+ bytes_to_gb = 1024**3
+
+ b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+ hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+ base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+ logits_gb = base_logits / final_m
+
+ total_mem_gb = hidden_gb + logits_gb
+
+ valid_mask = total_mem_gb <= limit_gb
+ valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+ if valid_indices.shape[0] == 0:
+ #This means your GPU will OOM
+ return 4, final_m
+
+ best_idx = valid_indices[0].item()
+ final_b = int(b_vals[best_idx].item())
+
+ return final_b, final_m
+
+ def sanitize_logprob(logprob):
+ """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+ Filters NaN logprobs from vLLM outputs."""
+ value = logprob.logprob
+ if math.isnan(value):
+ logging.getLogger(__name__).warning(
+ f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+ )
+ return None
+ return value
+ @dataclass
+ class UnslothORPOConfig(ORPOConfig):
+ """
+
+ Configuration class for the [`ORPOTrainer`].
+
+ This class includes only the parameters that are specific to ORPO training. For a full list of training arguments,
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+ differ from those in [`~transformers.TrainingArguments`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+ to use the default data collator.
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
+ max_completion_length (`int`, *optional*):
+ Maximum length of the completion. This argument is required if you want to use the default data collator
+ and your model is an encoder-decoder.
+ beta (`float`, *optional*, defaults to `0.1`):
+ Parameter controlling the relative ratio loss weight in the ORPO loss. In the
+ [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the
+ [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+ Whether to disable dropout in the model.
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
+ Label pad token id. This argument is required if you want to use the default data collator.
+ padding_value (`int`, *optional*):
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+ This argument is required if you want to use the default data collator.
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
+ is_encoder_decoder (`bool`, *optional*):
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+ you need to specify if the model returned by the callable is an encoder-decoder model.
+ model_init_kwargs (`dict[str, Any]`, *optional*):
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+ string.
+ dataset_num_proc (`int`, *optional*):
+ Number of processes to use for processing the dataset.
+
+ """
+ vllm_sampling_params: Optional[Any] = field(
+ default = None,
+ metadata = {'help': 'vLLM SamplingParams'},
+ )
+ unsloth_num_chunks : Optional[int] = field(
+ default = -1,
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+ )
+ unsloth_logit_chunk_multiplier : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Multiplier for chunked logit computations.'},
+ )
+ unsloth_grpo_mini_batch : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+ )
+ max_seq_length : Optional[int] = field(
+ default = None,
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
+ )
+ def __init__(
+ self,
+ output_dir = None,
+ per_device_train_batch_size = 4,
+ num_train_epochs = 3.0,
+ max_steps = -1,
+ learning_rate = 5e-05,
+ lr_scheduler_type = 'linear',
+ lr_scheduler_kwargs = None,
+ warmup_steps = 0.1,
+ optim = 'adamw_8bit',
+ optim_args = None,
+ weight_decay = 0.01,
+ adam_beta1 = 0.9,
+ adam_beta2 = 0.999,
+ adam_epsilon = 1e-08,
+ optim_target_modules = None,
+ gradient_accumulation_steps = 2,
+ average_tokens_across_devices = True,
+ max_grad_norm = 1.0,
+ label_smoothing_factor = 0.0,
+ bf16 = False,
+ fp16 = False,
+ bf16_full_eval = False,
+ fp16_full_eval = False,
+ tf32 = None,
+ gradient_checkpointing = True,
+ gradient_checkpointing_kwargs = None,
+ torch_compile = False,
+ torch_compile_backend = None,
+ torch_compile_mode = None,
+ use_liger_kernel = False,
+ liger_kernel_config = None,
+ use_cache = False,
+ neftune_noise_alpha = None,
+ torch_empty_cache_steps = 250,
+ auto_find_batch_size = False,
+ logging_strategy = 'steps',
+ logging_steps = 1,
+ logging_first_step = False,
+ log_on_each_node = True,
+ logging_nan_inf_filter = False,
+ include_num_input_tokens_seen = False,
+ log_level = 'passive',
+ log_level_replica = 'warning',
+ disable_tqdm = None,
+ report_to = 'none',
+ run_name = None,
+ project = 'huggingface',
+ trackio_space_id = 'trackio',
+ eval_strategy = 'no',
+ eval_steps = None,
+ eval_delay = 0,
+ per_device_eval_batch_size = 4,
+ prediction_loss_only = False,
+ eval_on_start = False,
+ eval_do_concat_batches = True,
+ eval_use_gather_object = False,
+ eval_accumulation_steps = 2,
+ batch_eval_metrics = False,
+ save_only_model = False,
+ save_strategy = 'steps',
+ save_steps = 500,
+ save_on_each_node = False,
+ save_total_limit = None,
+ enable_jit_checkpoint = False,
+ push_to_hub = False,
+ hub_token = None,
+ hub_private_repo = None,
+ hub_model_id = None,
+ hub_strategy = 'every_save',
+ hub_always_push = False,
+ hub_revision = None,
+ load_best_model_at_end = False,
+ metric_for_best_model = None,
+ greater_is_better = None,
+ ignore_data_skip = False,
+ restore_callback_states_from_checkpoint = False,
+ full_determinism = False,
+ seed = 3407,
+ data_seed = 3407,
+ use_cpu = False,
+ accelerator_config = None,
+ parallelism_config = None,
+ dataloader_drop_last = False,
+ dataloader_num_workers = 0,
+ dataloader_pin_memory = True,
+ dataloader_persistent_workers = False,
+ dataloader_prefetch_factor = None,
+ remove_unused_columns = True,
+ label_names = None,
+ train_sampling_strategy = 'random',
+ length_column_name = 'length',
+ ddp_find_unused_parameters = None,
+ ddp_bucket_cap_mb = None,
+ ddp_broadcast_buffers = None,
+ ddp_backend = None,
+ ddp_timeout = 1800,
+ fsdp = None,
+ fsdp_config = None,
+ deepspeed = None,
+ debug = '',
+ skip_memory_metrics = True,
+ do_train = False,
+ do_eval = False,
+ do_predict = False,
+ resume_from_checkpoint = None,
+ warmup_ratio = None,
+ logging_dir = None,
+ local_rank = -1,
+ max_length = 1024,
+ max_prompt_length = 512,
+ max_completion_length = None,
+ beta = 0.1,
+ disable_dropout = True,
+ label_pad_token_id = -100,
+ padding_value = None,
+ truncation_mode = 'keep_end',
+ generate_during_eval = False,
+ is_encoder_decoder = None,
+ model_init_kwargs = None,
+ dataset_num_proc = None,
+ vllm_sampling_params = None,
+ unsloth_num_chunks = -1,
+ unsloth_logit_chunk_multiplier = None,
+ unsloth_grpo_mini_batch = None,
+ max_seq_length = None,
+ **kwargs,
+ ):
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
521
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
522
+ if num_train_epochs is None:
523
+ num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
524
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
525
+ output_dir = 'unsloth_training_checkpoints'
526
+ save_strategy = 'no'
527
+ import multiprocessing as _mp
528
+ if _mp.get_start_method() != 'fork':
529
+ dataset_num_proc = None
530
+ elif dataset_num_proc is None:
531
+ import psutil
532
+ dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
533
+ memory_gb_left = psutil.virtual_memory().available / (1024**3)
534
+ if memory_gb_left <= 2: dataset_num_proc = 1
535
+ else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
536
+
537
+ super().__init__(
538
+ output_dir = output_dir,
539
+ per_device_train_batch_size = per_device_train_batch_size,
540
+ num_train_epochs = num_train_epochs,
541
+ max_steps = max_steps,
542
+ learning_rate = learning_rate,
543
+ lr_scheduler_type = lr_scheduler_type,
544
+ lr_scheduler_kwargs = lr_scheduler_kwargs,
545
+ warmup_steps = warmup_steps,
546
+ optim = optim,
547
+ optim_args = optim_args,
548
+ weight_decay = weight_decay,
549
+ adam_beta1 = adam_beta1,
550
+ adam_beta2 = adam_beta2,
551
+ adam_epsilon = adam_epsilon,
552
+ optim_target_modules = optim_target_modules,
553
+ gradient_accumulation_steps = gradient_accumulation_steps,
554
+ average_tokens_across_devices = average_tokens_across_devices,
555
+ max_grad_norm = max_grad_norm,
556
+ label_smoothing_factor = label_smoothing_factor,
557
+ bf16 = bf16,
558
+ fp16 = fp16,
559
+ bf16_full_eval = bf16_full_eval,
560
+ fp16_full_eval = fp16_full_eval,
561
+ tf32 = tf32,
562
+ gradient_checkpointing = gradient_checkpointing,
563
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
564
+ torch_compile = torch_compile,
565
+ torch_compile_backend = torch_compile_backend,
566
+ torch_compile_mode = torch_compile_mode,
567
+ use_liger_kernel = use_liger_kernel,
568
+ liger_kernel_config = liger_kernel_config,
569
+ use_cache = use_cache,
570
+ neftune_noise_alpha = neftune_noise_alpha,
571
+ torch_empty_cache_steps = torch_empty_cache_steps,
572
+ auto_find_batch_size = auto_find_batch_size,
573
+ logging_strategy = logging_strategy,
574
+ logging_steps = logging_steps,
575
+ logging_first_step = logging_first_step,
576
+ log_on_each_node = log_on_each_node,
577
+ logging_nan_inf_filter = logging_nan_inf_filter,
578
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
579
+ log_level = log_level,
580
+ log_level_replica = log_level_replica,
581
+ disable_tqdm = disable_tqdm,
582
+ report_to = report_to,
583
+ run_name = run_name,
584
+ project = project,
585
+ trackio_space_id = trackio_space_id,
586
+ eval_strategy = eval_strategy,
587
+ eval_steps = eval_steps,
588
+ eval_delay = eval_delay,
589
+ per_device_eval_batch_size = per_device_eval_batch_size,
590
+ prediction_loss_only = prediction_loss_only,
591
+ eval_on_start = eval_on_start,
592
+ eval_do_concat_batches = eval_do_concat_batches,
593
+ eval_use_gather_object = eval_use_gather_object,
594
+ eval_accumulation_steps = eval_accumulation_steps,
595
+ batch_eval_metrics = batch_eval_metrics,
596
+ save_only_model = save_only_model,
597
+ save_strategy = save_strategy,
598
+ save_steps = save_steps,
599
+ save_on_each_node = save_on_each_node,
600
+ save_total_limit = save_total_limit,
601
+ enable_jit_checkpoint = enable_jit_checkpoint,
602
+ push_to_hub = push_to_hub,
603
+ hub_token = hub_token,
604
+ hub_private_repo = hub_private_repo,
605
+ hub_model_id = hub_model_id,
606
+ hub_strategy = hub_strategy,
607
+ hub_always_push = hub_always_push,
608
+ hub_revision = hub_revision,
609
+ load_best_model_at_end = load_best_model_at_end,
610
+ metric_for_best_model = metric_for_best_model,
611
+ greater_is_better = greater_is_better,
612
+ ignore_data_skip = ignore_data_skip,
613
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
614
+ full_determinism = full_determinism,
615
+ seed = seed,
616
+ data_seed = data_seed,
617
+ use_cpu = use_cpu,
618
+ accelerator_config = accelerator_config,
619
+ parallelism_config = parallelism_config,
620
+ dataloader_drop_last = dataloader_drop_last,
621
+ dataloader_num_workers = dataloader_num_workers,
622
+ dataloader_pin_memory = dataloader_pin_memory,
623
+ dataloader_persistent_workers = dataloader_persistent_workers,
624
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
625
+ remove_unused_columns = remove_unused_columns,
626
+ label_names = label_names,
627
+ train_sampling_strategy = train_sampling_strategy,
628
+ length_column_name = length_column_name,
629
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
630
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
631
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
632
+ ddp_backend = ddp_backend,
633
+ ddp_timeout = ddp_timeout,
634
+ fsdp = fsdp,
635
+ fsdp_config = fsdp_config,
636
+ deepspeed = deepspeed,
637
+ debug = debug,
638
+ skip_memory_metrics = skip_memory_metrics,
639
+ do_train = do_train,
640
+ do_eval = do_eval,
641
+ do_predict = do_predict,
642
+ resume_from_checkpoint = resume_from_checkpoint,
643
+ warmup_ratio = warmup_ratio,
644
+ logging_dir = logging_dir,
645
+ local_rank = local_rank,
646
+ max_length = max_length,
647
+ max_prompt_length = max_prompt_length,
648
+ max_completion_length = max_completion_length,
649
+ beta = beta,
650
+ disable_dropout = disable_dropout,
651
+ label_pad_token_id = label_pad_token_id,
652
+ padding_value = padding_value,
653
+ truncation_mode = truncation_mode,
654
+ generate_during_eval = generate_during_eval,
655
+ is_encoder_decoder = is_encoder_decoder,
656
+ model_init_kwargs = model_init_kwargs,
657
+ dataset_num_proc = dataset_num_proc,**kwargs)
658
+ self.vllm_sampling_params = vllm_sampling_params
659
+ self.unsloth_num_chunks = unsloth_num_chunks
660
+ if unsloth_grpo_mini_batch is not None:
661
+ if self.generation_batch_size >= unsloth_grpo_mini_batch:
662
+ self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
663
+ else:
664
+ raise ValueError(
665
+ f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
666
+ f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
667
+ )
668
+ self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
669
+ self.max_seq_length = max_seq_length
670
+
671
+ pass
672
+
673
+class _UnslothORPOTrainer(BaseTrainer):
+    r""""""
+
+    _tag_names = ["trl", "orpo"]
+    _name = "ORPO"
+    _paper = {
+        "title": "ORPO: Monolithic Preference Optimization without Reference Model",
+        "id": "2403.07691",
+        # docstyle-ignore
+        "citation": textwrap.dedent("""\
+            @article{hong2024orpo,
+                title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
+                author = {Jiwoo Hong and Noah Lee and James Thorne},
+                year = 2024,
+                eprint = {arXiv:2403.07691}
+            }"""),
+    }
+
+    def __init__(
+        self,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+        args: Optional[ORPOConfig] = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        peft_config: Optional[dict] = None,
+        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+    ):
708
+        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
+            warnings.warn(
+                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
+                "it and want it to remain, please share your comments here: "
+                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
+                "TRL_EXPERIMENTAL_SILENCE=1."
+            )
+        if args.model_init_kwargs is None:
+            model_init_kwargs = {}
+        elif not isinstance(model, str):
+            raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
+        else:
+            model_init_kwargs = args.model_init_kwargs
+            dtype = model_init_kwargs.get("dtype")
+            if dtype is not None:
+                # Convert to `torch.dtype` if an str is passed
+                if isinstance(dtype, str) and dtype != "auto":
+                    dtype = getattr(torch, dtype)
+                if dtype != "auto" and not isinstance(dtype, torch.dtype):
+                    raise ValueError(
+                        f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
+                    )
+                model_init_kwargs["dtype"] = dtype
+
732
+        if isinstance(model, str):
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+        # has been called in order to properly call autocast if needed.
+        self._peft_has_been_casted_to_bf16 = False
+
+        if not is_peft_available() and peft_config is not None:
+            raise ValueError(
+                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+            )
+        elif is_peft_available() and peft_config is not None:
+            # if model is a peft model and we have a peft_config, we merge and unload it first
+            if isinstance(model, PeftModel):
+                model = model.merge_and_unload()
+
+            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
+                )
+
+                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                if _support_gc_kwargs:
+                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+            elif args.gradient_checkpointing:
+                # For backward compatibility with older versions of transformers
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(module, input, output):
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+            # get peft model with the given config
+            model = model
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)
+                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+                self._peft_has_been_casted_to_bf16 = True
+
+        # For models that use gradient_checkpointing, we need to attach a hook that enables input
+        # to explicitly have `requires_grad=True`, otherwise training will either silently
+        # fail or completely fail.
+        elif args.gradient_checkpointing:
+            # For backward compatibility with older versions of transformers
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
+
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
+
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+            raise ValueError(
+                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                " Please install `wandb` or `comet-ml` to resolve."
+            )
+
799
+        if model is not None:
+            self.is_encoder_decoder = model.config.is_encoder_decoder
+        elif args.is_encoder_decoder is None:
+            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+        else:
+            self.is_encoder_decoder = args.is_encoder_decoder
+
+        if self.is_encoder_decoder:
+            self.decoder_start_token_id = model.config.decoder_start_token_id
+            self.pad_token_id = model.config.pad_token_id
+
+        if processing_class is None:
+            raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
+        if args.max_length is None:
+            logger.warning(
+                "`max_length` is not set in the ORPOConfig's init"
+                " it will default to `512` by default, but you should do it yourself in the future.",
+            )
+            max_length = 512
+        else:
+            max_length = args.max_length
+        if args.max_prompt_length is None:
+            logger.warning(
+                "`max_prompt_length` is not set in the ORPOConfig's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+            )
+            max_prompt_length = 128
+        else:
+            max_prompt_length = args.max_prompt_length
+
+        if args.max_completion_length is None and self.is_encoder_decoder:
+            logger.warning(
+                "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
+                " it will default to `128` by default, but you should do it yourself in the future.",
+            )
+            self.max_completion_length = 128
+        else:
+            self.max_completion_length = args.max_completion_length
+
838
+        if data_collator is None:
+            data_collator = DPODataCollatorWithPadding(
+                pad_token_id=processing_class.pad_token_id,
+                label_pad_token_id=args.label_pad_token_id,
+                is_encoder_decoder=self.is_encoder_decoder,
+            )
+
+            if args.remove_unused_columns:
+                args.remove_unused_columns = False
+                # warn users
+                logger.warning(
+                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
+                    " we have set it for you, but you should do it yourself in the future.",
+                )
+
+            self.use_dpo_data_collator = True
+        else:
+            self.use_dpo_data_collator = False
+
+        # Disable dropout in the model and reference model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
+        self.max_length = max_length
+        self.generate_during_eval = args.generate_during_eval
+        self.label_pad_token_id = args.label_pad_token_id
+        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+        self.max_prompt_length = max_prompt_length
+        self.truncation_mode = args.truncation_mode
+        self.processing_class = processing_class
+
+        self.beta = args.beta
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+            logger.warning(
+                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                "loss.",
+            )
+
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
+        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
+        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
+        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
+        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
+        # that the warning has already been issued.
+        model.warnings_issued["estimate_tokens"] = True
+
891
+        # Compute that only on the main process for faster data processing.
+        # see: https://github.com/huggingface/trl/pull/1255
+        with PartialState().main_process_first():
+            # Extract the prompt if needed, and apply the chat template if needed
+            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+            train_dataset = train_dataset.map(
+                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
+            )
+            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+            if eval_dataset is not None:
+                eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+                eval_dataset = eval_dataset.map(
+                    maybe_apply_chat_template,
+                    fn_kwargs={"tokenizer": processing_class},
+                    num_proc=args.dataset_num_proc,
+                )
+                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+        if not hasattr(self, "accelerator"):
+            raise AttributeError(
+                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+            )
+
937
+    def build_tokenized_answer(self, prompt, answer):
+        """
+        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
+        b)[len(enc(a)):]`. Reference:
+        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+        """
+
+        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
+        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
+
+        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
+        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
+
+        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
+        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
+
+        # Prepare input tokens for token by token comparison
+        full_input_ids = np.array(full_tokenized["input_ids"])
+
+        if len(full_input_ids) != len(full_concat_input_ids):
+            raise ValueError("Prompt input ids and answer input ids should have the same length.")
+
+        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
+        # can be merged together when tokenizing prompt+answer. This could result
+        # on the last token from the prompt being different when tokenized on its own
+        # vs when done as prompt+answer.
+        response_token_ids_start_idx = len(prompt_input_ids)
+
+        # If tokenized prompt is different than both prompt+answer, then it means the
+        # last token has changed due to merging.
+        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
+            response_token_ids_start_idx -= 1
+
+        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
+        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
+
+        if len(prompt_input_ids) != len(prompt_attention_mask):
+            raise ValueError("Prompt input ids and attention mask should have the same length.")
+
+        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
+        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
+
+        return dict(
+            prompt_input_ids=prompt_input_ids,
+            prompt_attention_mask=prompt_attention_mask,
+            input_ids=answer_input_ids,
+            attention_mask=answer_attention_mask,
+        )
+
986
+    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
+        """Tokenize a single row from an ORPO-specific dataset.
+
+        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
+        chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
+        we truncate the chosen/rejected.
+
+        We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
+        of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
+        """
+        batch = {}
+        prompt = feature["prompt"]
+        chosen = feature["chosen"]
+        rejected = feature["rejected"]
+
+        if not self.is_encoder_decoder:
+            # Check issues below for more details
+            #  1. https://github.com/huggingface/trl/issues/907
+            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+            #  3. https://github.com/LianjiaTech/BELLE/issues/337
+
+            if not isinstance(prompt, str):
+                raise ValueError(f"prompt should be an str but got {type(prompt)}")
+            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
+            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
+
+            if not isinstance(chosen, str):
+                raise ValueError(f"chosen should be an str but got {type(chosen)}")
+            chosen_tokens = self.build_tokenized_answer(prompt, chosen)
+
+            if not isinstance(rejected, str):
+                raise ValueError(f"rejected should be an str but got {type(rejected)}")
+            rejected_tokens = self.build_tokenized_answer(prompt, rejected)
+
+            # Last prompt token might get merged by tokenizer and
+            # it should not be included for generation if that happens
+            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
+
+            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
+            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
+            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
+
+            for k, v in prompt_tokens.items():
+                prompt_tokens[k] = v[:prompt_len_input_ids]
+
+            # Make sure prompts only have one different token at most, and
+            # length only differs by 1 at most
+            num_diff_tokens = sum(
+                a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
+            )
+            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
+            if num_diff_tokens > 1 or num_diff_len > 1:
+                raise ValueError(
+                    "Chosen and rejected prompt_input_ids might only differ on the "
+                    "last token due to tokenizer merge ops."
+                )
+
1043
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                self.processing_class.bos_token_id,
+                prompt_len_input_ids,
+                prompt_tokens,
+                chosen_prompt_len_input_ids,
+                chosen_tokens,
+                rejected_prompt_len_input_ids,
+                rejected_tokens,
+            )
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
+            )
+
+            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
+
+            # if combined sequence is too long, truncate the prompt
+            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
+                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                    if self.truncation_mode == "keep_start":
+                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
+                    elif self.truncation_mode == "keep_end":
+                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
+                    else:
+                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
+
+            # if that's still too long, truncate the response
+            for answer_tokens in [chosen_tokens, rejected_tokens]:
+                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                    for k in ["input_ids", "attention_mask"]:
+                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
+
+            # Create labels
+            chosen_sequence_tokens = {
+                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
+            }
+            rejected_sequence_tokens = {
+                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
+            }
+            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
+            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(chosen_tokens["prompt_input_ids"])
+            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
+            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(rejected_tokens["prompt_input_ids"])
+
+            for k, toks in {
+                "chosen_": chosen_sequence_tokens,
+                "rejected_": rejected_sequence_tokens,
+                "": prompt_tokens,
+            }.items():
+                for type_key, tokens in toks.items():
+                    if type_key == "token_type_ids":
+                        continue
+                    batch[f"{k}{type_key}"] = tokens
+
+        else:
+            chosen_tokens = self.processing_class(
+                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+            )
+            rejected_tokens = self.processing_class(
+                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+            )
+            prompt_tokens = self.processing_class(
+                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
+            )
+
+            batch["chosen_labels"] = chosen_tokens["input_ids"]
+            batch["rejected_labels"] = rejected_tokens["input_ids"]
+            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
+            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+
+            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
+                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                    labels=torch.tensor(batch["rejected_labels"])
+                )
+                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                    labels=torch.tensor(batch["chosen_labels"])
+                )
+
+        if is_torch_xla_available():
+            # Pad the sequences to global max_length to avoid TorchXLA recompilation
+            for k in batch:
+                if "labels" in k or self.is_encoder_decoder:
+                    pad_value = self.label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = self.padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
+        return batch
+
1141
+    @staticmethod
+    def concatenated_inputs(
+        batch: dict[str, Union[list, torch.LongTensor]],
+        is_encoder_decoder: bool = False,
+        label_pad_token_id: int = -100,
+        padding_value: int = 0,
+        device: Optional[torch.device] = None,
+    ) -> dict[str, torch.LongTensor]:
+        """Concatenate the chosen and rejected inputs into a single tensor.
+
+        Args:
+            batch:
+                A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
+                of shape (batch_size, sequence_length).
+            is_encoder_decoder:
+                Whether the model is an encoder-decoder model.
+            label_pad_token_id:
+                The label pad token id.
+            padding_value:
+                The padding value to use for the concatenated inputs_ids.
+            device:
+                The device for the concatenated inputs.
+
+        Returns:
+            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+        """
+        concatenated_batch = {}
+
+        if is_encoder_decoder:
+            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
+        else:
+            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
+
+        for k in batch:
+            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
+                if "labels" in k or is_encoder_decoder:
+                    pad_value = label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                concatenated_key = k.replace("chosen", "concatenated")
+                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+        for k in batch:
+            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
+                if "labels" in k or is_encoder_decoder:
+                    pad_value = label_pad_token_id
+                elif k.endswith("_input_ids"):
+                    pad_value = padding_value
+                elif k.endswith("_attention_mask"):
+                    pad_value = 0
+                concatenated_key = k.replace("rejected", "concatenated")
+                concatenated_batch[concatenated_key] = torch.cat(
+                    (
+                        concatenated_batch[concatenated_key],
+                        pad_to_length(batch[k], max_length, pad_value=pad_value),
+                    ),
+                    dim=0,
+                ).to(device=device)
+
+        if is_encoder_decoder:
+            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
+            concatenated_batch["concatenated_attention_mask"] = (
+                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
+            )
+
+        return concatenated_batch
+
1209
+    def odds_ratio_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            policy_chosen_logps:
+                Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps:
+                Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen). The
+            losses tensor contains the ORPO loss for each example in the batch. The chosen_rewards and
+            rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. The last
+            two are the mean `log(sigmoid(log_odds))` and the mean log odds of the chosen responses over the rejected
+            responses, both for logging purposes.
+        """
+
+        # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
+        log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+            torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        losses = self.beta * ratio
+
+        chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
+
+        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
+
1241
+    @staticmethod
+    def get_batch_logps(
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+        label_pad_token_id: int = -100,
+        is_encoder_decoder: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+            labels:
+                Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
+                ignored. Shape: (batch_size, sequence_length)
+            average_log_prob:
+                If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
+                log probabilities of the (non-masked) tokens.
+            label_pad_token_id: The label pad token id.
+            is_encoder_decoder: Whether the model is an encoder-decoder model.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
+            given logits.
+        """
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+        if not is_encoder_decoder:
+            labels = labels[:, 1:].clone()
+            logits = logits[:, :-1, :]
+        loss_mask = labels != label_pad_token_id
+
+        # dummy token; we'll ignore the losses on these tokens later
+        labels = torch.where(labels == label_pad_token_id, 0, labels)
+
+        per_token_logps = selective_log_softmax(logits, labels)
+
+        if average_log_prob:
+            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            return (per_token_logps * loss_mask).sum(-1)
+
1284
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+        len_chosen = batch["chosen_labels"].shape[0]
+
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        outputs = model(
+            concatenated_batch["concatenated_input_ids"],
+            attention_mask=concatenated_batch["concatenated_attention_mask"],
+            use_cache=False,
+            **model_kwargs,
+        )
+        all_logits = outputs.logits
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+        # orpo chosen nll loss is computed over the full prompt and response
+        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+
+        all_logps = self.get_batch_logps(
+            all_logits,
+            concatenated_batch["concatenated_labels"],
+            average_log_prob=True,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+        )
+
+        chosen_logps = all_logps[:len_chosen]
+        rejected_logps = all_logps[len_chosen:]
+
+        if not self.is_encoder_decoder:
+            chosen_logits = all_logits[:len_chosen, :-1, :]
+            rejected_logits = all_logits[len_chosen:, :-1, :]
+        else:
+            chosen_logits = all_logits[:len_chosen]
+            rejected_logits = all_logits[len_chosen:]
+
+        if self.aux_loss_enabled:
+            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
+
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
+
1365
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+
+        forward_output = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
+
+        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
+            policy_chosen_logps, policy_rejected_logps
+        )
+        # full ORPO loss
+        loss = policy_nll_loss - losses.mean()
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
+            chosen_rewards - rejected_rewards
+        ).mean()
+        metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
+        metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
+        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
+            policy_rejected_logits.detach().mean()
+        ).mean()
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
+            policy_chosen_logits.detach().mean()
+        ).mean()
+        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
+        metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
+        if is_torch_xla_available():
+            xm.mark_step()  # needed because .item() calls
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+        if self.aux_loss_enabled:
+            loss += self.aux_loss_coef * aux_loss
+
+        return loss, metrics
+
1420
+ def compute_loss(
1421
+ self,
1422
+ model: Union[PreTrainedModel, nn.Module],
1423
+ inputs: dict[str, Union[torch.Tensor, Any]],
1424
+ return_outputs=False,
1425
+ num_items_in_batch=None,
1426
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1427
+ compute_loss_context_manager = (
1428
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
1429
+ )
1430
+
1431
+ with compute_loss_context_manager:
1432
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1433
+
1434
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1435
+ loss = loss.to(self.args.device)
1436
+
1437
+ # force log the metrics
1438
+ self.store_metrics(metrics, train_eval="train")
1439
+
1440
+ if return_outputs:
1441
+ return (loss, metrics)
1442
+ return loss
1443
+
1444
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1445
+ """Generate samples from the model and reference model for the given batch of inputs."""
1446
+
1447
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1448
+ # the torch amp context manager as some hidden states are silently casted to full precision.
1449
+ generate_context_manager = (
1450
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
1451
+ )
1452
+
1453
+ with generate_context_manager:
1454
+ policy_output = model.generate(
1455
+ input_ids=batch["prompt_input_ids"],
1456
+ attention_mask=batch["prompt_attention_mask"],
1457
+ max_length=self.max_length,
1458
+ do_sample=True,
1459
+ pad_token_id=self.processing_class.pad_token_id,
1460
+ )
1461
+
1462
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1463
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1464
+
1465
+ return policy_output_decoded
1466
+
1467
+ def prediction_step(
1468
+ self,
1469
+ model: Union[PreTrainedModel, nn.Module],
1470
+ inputs: dict[str, Union[torch.Tensor, Any]],
1471
+ prediction_loss_only: bool,
1472
+ ignore_keys: Optional[list[str]] = None,
1473
+ ):
1474
+ if not self.use_dpo_data_collator:
1475
+ logger.warning(
1476
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1477
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1478
+ )
1479
+ if ignore_keys is None:
1480
+ if hasattr(model, "config"):
1481
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1482
+ else:
1483
+ ignore_keys = []
1484
+
1485
+ prediction_context_manager = (
1486
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
1487
+ )
1488
+
1489
+ with torch.no_grad(), prediction_context_manager:
1490
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1491
+
1492
+ # force log the metrics
1493
+ self.store_metrics(metrics, train_eval="eval")
1494
+
1495
+ if prediction_loss_only:
1496
+ return (loss.detach(), None, None)
1497
+
1498
+ # logits for the chosen and rejected samples from model
1499
+ logits_dict = {
1500
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1501
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1502
+ }
1503
+ logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
1504
+ logits = torch.tensor(logits, device=self.accelerator.device)
1505
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1506
+
1507
+ return (loss.detach(), logits, labels)
1508
+
1509
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1510
+ for key, value in metrics.items():
1511
+ self._stored_metrics[train_eval][key].append(value)
1512
+
1513
+ def evaluation_loop(
1514
+ self,
1515
+ dataloader: DataLoader,
1516
+ description: str,
1517
+ prediction_loss_only: Optional[bool] = None,
1518
+ ignore_keys: Optional[list[str]] = None,
1519
+ metric_key_prefix: str = "eval",
1520
+ ) -> EvalLoopOutput:
1521
+ """
1522
+ Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
1523
+ `Trainer.evaluate()` and `Trainer.predict()`.
1524
+
1525
+ Works both with or without labels.
1526
+ """
1527
+
1528
+ # Sample and save to game log if requested (for one batch to save time)
1529
+ if self.generate_during_eval:
1530
+ # Generate random indices within the range of the total number of samples
1531
+ num_samples = len(dataloader.dataset)
1532
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1533
+
1534
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1535
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1536
+ random_batch = self.data_collator(random_batch_dataset)
1537
+ random_batch = self._prepare_inputs(random_batch)
1538
+
1539
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1540
+
1541
+ table = pd.DataFrame(
1542
+ columns=["Prompt", "Policy"],
1543
+ data=[
1544
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1545
+ ],
1546
+ )
1547
+ if "wandb" in self.args.report_to:
1548
+ wandb.log({"game_log": wandb.Table(data=table)})
1549
+
1550
+ if "comet_ml" in self.args.report_to:
1551
+ log_table_to_comet_experiment(
1552
+ name="game_log.csv",
1553
+ table=table,
1554
+ )
1555
+
1556
+ # Base evaluation
1557
+ initial_output = super().evaluation_loop(
1558
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1559
+ )
1560
+
1561
+ return initial_output
1562
+
1563
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1564
+ """
1565
+ Log `logs` on the various objects watching training, including stored metrics.
1566
+
1567
+ Args:
1568
+ logs (`dict[str, float]`):
1569
+ The values to log.
1570
+ start_time (`float`, *optional*):
1571
+ Start time of the training.
1572
+ """
1573
+ # logs either has 'loss' or 'eval_loss'
1574
+ train_eval = "train" if "loss" in logs else "eval"
1575
+ # Add averaged stored metrics to logs
1576
+ for key, metrics in self._stored_metrics[train_eval].items():
1577
+ logs[key] = torch.tensor(metrics).mean().item()
1578
+ del self._stored_metrics[train_eval]
1579
+ return super().log(logs, start_time)
1580
+
1581
+ def _shift_right(self, input_ids):
1582
+ if self.decoder_start_token_id is None:
1583
+ raise ValueError(
1584
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1585
+ )
1586
+
1587
+ # shift inputs to the right
1588
+ if is_torch_fx_proxy(input_ids):
1589
+ # Item assignment is not supported natively for proxies.
1590
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1591
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1592
+ else:
1593
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1594
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1595
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1596
+
1597
+ if self.pad_token_id is None:
1598
+ raise ValueError("model.config.pad_token_id has to be defined.")
1599
+ # replace possible -100 values in labels by `pad_token_id`
1600
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1601
+
1602
+ return shifted_input_ids
1603
+
1604
+ # Ensure the model card is saved along with the checkpoint
1605
+ def _save_checkpoint(self, model, trial):
1606
+ if self.args.hub_model_id is None:
1607
+ model_name = Path(self.args.output_dir).name
1608
+ else:
1609
+ model_name = self.args.hub_model_id.split("/")[-1]
1610
+ self.create_model_card(model_name=model_name)
1611
+ super()._save_checkpoint(model, trial)
1612
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1613
+ """
1614
+
1615
+ Initialize ORPOTrainer.
1616
+
1617
+ Args:
1618
+ model ([`~transformers.PreTrainedModel`]):
1619
+ The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
1620
+ args ([`ORPOConfig`]):
1621
+ The ORPO config arguments to use for training.
1622
+ data_collator ([`~transformers.DataCollator`]):
1623
+ The data collator to use for training. If None is specified, the default data collator
1624
+ ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
1625
+ sequences in the batch, given a dataset of paired sequences.
1626
+ train_dataset ([`~datasets.Dataset`]):
1627
+ The dataset to use for training.
1628
+ eval_dataset ([`~datasets.Dataset`]):
1629
+ The dataset to use for evaluation.
1630
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
1631
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1632
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1633
+ reuse the fine-tuned model.
1634
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1635
+ The model initializer to use for training. If None is specified, the default model initializer will be
1636
+ used.
1637
+ callbacks (`list[transformers.TrainerCallback]`):
1638
+ The callbacks to use for training.
1639
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1640
+ The optimizer and scheduler to use for training.
1641
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1642
+ The function to use to preprocess the logits before computing the metrics.
1643
+ peft_config (`dict`, defaults to `None`):
1644
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
1645
+ a PEFT model.
1646
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1647
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
1648
+ metric values.
1649
+
1650
+ """
1651
+ def __init__(
1652
+ self,
1653
+ model = None,
1654
+ args = None,
1655
+ data_collator = None,
1656
+ train_dataset = None,
1657
+ eval_dataset = None,
1658
+ processing_class = None,
1659
+ model_init = None,
1660
+ callbacks = None,
1661
+ preprocess_logits_for_metrics = None,
1662
+ peft_config = None,
1663
+ compute_metrics = None,
1664
+ **kwargs
1665
+ ):
1666
+ if args is None: args = UnslothORPOConfig()
1667
+ use_bf16 = getattr(args, 'bf16', False)
1668
+ if type(use_bf16) is not bool: use_bf16 = False
1669
+ use_fp16 = getattr(args, 'fp16', False)
1670
+ if type(use_fp16) is not bool: use_fp16 = False
1671
+ force_float32 = False
1672
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
1673
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
1674
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1675
+ force_float32 = True
1676
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1677
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
1678
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
1679
+ from unsloth_zoo.utils import _get_dtype
1680
+ dtype = _get_dtype(dtype)
1681
+ float16 = dtype == torch.float16
1682
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1683
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1684
+ if force_float32:
1685
+ # Forced float32 training
1686
+ args.fp16 = False
1687
+ args.bf16 = False
1688
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1689
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1690
+ # args.mixed_precision is a new argument which needs to be set now
1691
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1692
+ # Mixed precision training
1693
+ args.fp16 = float16
1694
+ args.bf16 = not float16
1695
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1696
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
1697
+ # args.mixed_precision is a new argument which needs to be set now
1698
+ elif mixed_precision_dtype == 'bfloat16':
1699
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
1700
+ args.fp16 = False
1701
+ args.bf16 = False
1702
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1703
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1704
+ # args.mixed_precision is a new argument which needs to be set now
1705
+
1706
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1707
+ args.eval_strategy = 'steps'
1708
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1709
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1710
+ if ga_steps is not None and ga_steps > 1:
1711
+ from transformers import __version__ as transformers_version
1712
+ if Version(transformers_version) <= Version('4.45.2'):
1713
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1714
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1715
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1716
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1717
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1718
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1719
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1720
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1721
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1722
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1723
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1724
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1725
+ if force_float32:
1726
+ args.bf16_full_eval = False
1727
+ args.fp16_full_eval = False
1728
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1729
+ args.bf16_full_eval = True
1730
+ args.fp16_full_eval = False
1731
+ elif not bf16_full_eval and not fp16_full_eval:
1732
+ args.bf16_full_eval = args.bf16
1733
+ args.fp16_full_eval = args.fp16
1734
+ _output_logits = False
1735
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1736
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1737
+ if _output_logits:
1738
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1739
+ if model is not None:
1740
+ _warnings_issued = getattr(model, 'warnings_issued', None)
1741
+ if _warnings_issued is None:
1742
+ model.warnings_issued = {}
1743
+ elif not isinstance(_warnings_issued, dict):
1744
+ try:
1745
+ model.warnings_issued = dict(_warnings_issued)
1746
+ except Exception:
1747
+ model.warnings_issued = {}
1748
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1749
+ pass
1750
+ else:
1751
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1752
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1753
+ if args_max_seq_length is None and model_max_seq_length is not None:
1754
+ max_seq_length = model.max_seq_length
1755
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1756
+ elif args_max_seq_length is not None and model_max_seq_length is not None:
1757
+ if args_max_seq_length > model_max_seq_length:
1758
+ print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
1759
+ 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
1760
+ args.max_seq_length = model_max_seq_length
1761
+ if model is not None and hasattr(model, 'for_training'):
1762
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1763
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1764
+ if 'processing_class' in locals():
1765
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1766
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1767
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1768
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1769
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1770
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1771
+ data_collator = TransformersDataCollatorForLanguageModeling(
1772
+ __tokenizer,
1773
+ mlm = False,
1774
+ mlm_probability = 0.0,
1775
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1776
+ )
1777
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1778
+ data_collator = DataCollatorForSeq2Seq(
1779
+ __tokenizer,
1780
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1781
+ )
1782
+ else:
1783
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1784
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1785
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1786
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1787
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1788
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1789
+ data_collator = DataCollatorForSeq2Seq(
1790
+ __tokenizer.tokenizer,
1791
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1792
+ )
1793
+ else:
1794
+ data_collator = TransformersDataCollatorForLanguageModeling(
1795
+ __tokenizer.tokenizer,
1796
+ mlm = False,
1797
+ mlm_probability = 0.0,
1798
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1799
+ )
1800
+ other_metrics = []
1801
+
1802
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1803
+ PatchRLStatistics('orpo_trainer', other_metrics)
1804
+
1805
+ # [TODO] Fix up DataParallel multiplying batch sizes
1806
+ # [TODO] DDP works, but DP seems to not work? [TODO]
1807
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
1808
+ if getattr(args, "_n_gpu", 1) != 1:
1809
+ args._n_gpu = 1
1810
+ if "model" in locals() and hasattr(model, "for_training"):
1811
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1812
+ super().__init__(
1813
+ model = model,
1814
+ args = args,
1815
+ data_collator = data_collator,
1816
+ train_dataset = train_dataset,
1817
+ eval_dataset = eval_dataset,
1818
+ processing_class = processing_class,
1819
+ model_init = model_init,
1820
+ callbacks = callbacks,
1821
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1822
+ peft_config = peft_config,
1823
+ compute_metrics = compute_metrics,**kwargs)
1824
+ if "model" in locals() and hasattr(model, "for_inference"):
1825
+ model.for_inference()
1826
+ if hasattr(self, 'neftune_hook_handle'):
1827
+ self.neftune_hook_handle.remove()
1828
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1829
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1830
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1831
+ pass
1832
+ if hasattr(self, 'accelerator'):
1833
+ scaler = self.accelerator.scaler
1834
+ current_model = model
1835
+ while hasattr(current_model, 'model'):
1836
+ current_model.accelerator_scaler = scaler
1837
+ current_model = current_model.model
1838
+ current_model.accelerator_scaler = scaler
1839
+ pass
1840
+ if hasattr(self, 'train'):
1841
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1842
+ pass
1843
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1844
+ _vllm_tok = self.llm.get_tokenizer()
1845
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1846
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1847
+ _vllm_tok.chat_template = _pc.chat_template
1848
+ pass
1849
+
1850
+ pass
1851
+
1852
+
1853
+ if hasattr(logger, "addFilter"):
1854
+ import logging
1855
+ class HideLoggingMessage(logging.Filter):
1856
+ def __init__(self, text): self.text = text
1857
+ def filter(self, x): return not (self.text in x.getMessage())
1858
+ pass
1859
+ logger.addFilter(HideLoggingMessage("`use_cache=True`"))
1860
+
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1634 @@
1
+ """
2
+ 2026.3.2
3
+ 2026.3.4
4
+ 5.3.0
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from unsloth_zoo.temporary_patches.common import torch_compile
30
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
31
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, BaseTrainer, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, warnings, Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, OnlineTrainerState, Optional, PPOConfig, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, broadcast, create_reference_model, disable_dropout_in_model, exact_div, forward, get_peft_model, get_reporting_integration_callbacks, is_peft_available, math, nn, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, time, torch, warnings, PeftModel, is_peft_available, os, torch)
32
+
33
+
34
+ import os
35
+ import math
36
+ import logging
37
+ from typing import *
38
+ from dataclasses import dataclass, field
39
+ from packaging.version import Version
40
+ import torch
41
+ import numpy as np
42
+ from contextlib import nullcontext
43
+ from torch.nn import functional as F
44
+ import inspect
45
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
46
+ from transformers.training_args import ParallelMode
47
+ from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
48
+
49
+ # Wrap trainer with padding to right and enable training mode
50
+ # Also patches W&B since multiple runs must use wandb.finish()
51
+ import functools
52
+ from types import MethodType
53
+ try:
54
+ from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
55
+ except:
56
+ def reset_unsloth_gradient_checkpointing_buffers(): pass
57
+ def prepare_for_training_mode(f):
58
+ @functools.wraps(f)
59
+ def wrapper(self, *args, **kwargs):
60
+ # Enable training mode
61
+ _was_training = None
62
+ # Get gradient checkpointing setting from training arguments
63
+ use_gc = getattr(self.args, 'gradient_checkpointing', True)
64
+ if hasattr(self, 'model') and hasattr(self.model, "training"):
65
+ _was_training = self.model.training
66
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
67
+ self.model.for_training(use_gradient_checkpointing=use_gc)
68
+ output = f(self, *args, **kwargs)
69
+ # Restore previous mode when possible
70
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
71
+ if _was_training is False:
72
+ self.model.for_inference()
73
+ elif _was_training is True and hasattr(self.model, "for_training"):
74
+ self.model.for_training(use_gradient_checkpointing=use_gc)
75
+ # Reset gradient checkpointing buffers to free memory while staying ready for next run
76
+ try:
77
+ reset_unsloth_gradient_checkpointing_buffers()
78
+ except:
79
+ pass
80
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
81
+ try:
82
+ import wandb
83
+ wandb.finish()
84
+ except:
85
+ pass
86
+ return output
87
+ return wrapper
88
+ pass
89
+
90
+ torch_compile_options = {
91
+ "epilogue_fusion" : True,
92
+ "max_autotune" : False,
93
+ "shape_padding" : True,
94
+ "trace.enabled" : False,
95
+ "triton.cudagraphs" : False,
96
+ }
97
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_hidden_states_selective_log_softmax(
+     hidden_states: torch.Tensor,
+     lm_head: torch.Tensor,
+     index: torch.Tensor,
+     chunks: int = 4,
+     logit_scale_multiply: float = 0.0,
+     logit_scale_divide: float = 0.0,
+     logit_softcapping: float = 0.0,
+     temperature: float = 1.0,
+ ) -> torch.Tensor:
+     # All Unsloth Zoo code licensed under AGPL3
+     flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+     flat_index = index.reshape(-1)
+
+     chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+     chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+     all_per_token_logps = []
+
+     for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+         chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+         if logit_scale_multiply != 0.0:
+             chunk_logits = chunk_logits * logit_scale_multiply
+         if logit_scale_divide != 0.0:
+             chunk_logits = chunk_logits / logit_scale_divide
+         if logit_softcapping != 0.0:
+             chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+         chunk_logits = chunk_logits.to(torch.float32)
+
+         if temperature != 1.0:
+             chunk_logits = chunk_logits / temperature
+
+         selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+         logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+         per_token_logps = selected_logits - logsumexp_values
+         all_per_token_logps.append(per_token_logps)
+
+     all_per_token_logps = torch.concat(all_per_token_logps)
+
+     all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+     return all_per_token_logps
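The gather/logsumexp pattern above computes `log_softmax(logits)[index]` without materialising the full log-softmax tensor. A minimal pure-Python sketch of that identity (names here are illustrative, not part of the Unsloth API):

```python
import math

def selective_log_softmax(logits, index):
    # log_softmax(logits)[index] == logits[index] - logsumexp(logits)
    logsumexp = math.log(sum(math.exp(x) for x in logits))
    return logits[index] - logsumexp

# Matches taking the full log-softmax first, then gathering one entry.
logits = [2.0, 1.0, 0.5]
full = [x - math.log(sum(math.exp(y) for y in logits)) for x in logits]
assert abs(selective_log_softmax(logits, 1) - full[1]) < 1e-12
```

The chunked kernels apply the same subtraction per chunk, which keeps only one chunk of logits alive at a time.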
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_selective_log_softmax(logits, index):
+     # Split into 4 chunks only
+     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+     all_per_token_logps = []
+     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+         chunk_logits = chunk_logits.to(torch.float32)
+         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+         per_token_logps = selected_logits - logsumexp_values
+         all_per_token_logps.append(per_token_logps)
+     pass
+     all_per_token_logps = torch.concat(all_per_token_logps)
+     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+     return all_per_token_logps
+
+ def calculate_pad_tokens_in_prompt(
+     input_ids: torch.Tensor,
+     logits_to_keep: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given the full input_ids tensor, returns the number of left-pad tokens in the
+     prompt section of each sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
+     """
+     if logits_to_keep >= input_ids.shape[1]:
+         raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+     prompt_section = input_ids[:, :-logits_to_keep]
+
+     padding_mask = (prompt_section == pad_token_id)
+
+     pad_token_counts = padding_mask.sum(dim=1)
+
+     return pad_token_counts
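In pure Python the same per-row count looks like this (a sketch with `pad_id=0` as an illustrative pad token; the torch version vectorises it with a boolean mask and a row-wise sum):

```python
def count_prompt_pads(row, logits_to_keep, pad_id=0):
    # The prompt section is everything except the trailing
    # logits_to_keep completion positions.
    if logits_to_keep >= len(row):
        raise ValueError("logits_to_keep must be smaller than the sequence length.")
    prompt_section = row[:-logits_to_keep]
    return sum(1 for tok in prompt_section if tok == pad_id)

# [pad, pad, pad, cat] followed by 2 completion tokens -> 3 prompt pads
assert count_prompt_pads([0, 0, 0, 7, 8, 9], logits_to_keep=2) == 3
```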
+
+ def create_completion_attention_mask(
+     completion_input_ids: torch.Tensor,
+     left_pad_tokens_per_prompt: torch.Tensor,
+     max_left_pad: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given a sequence [p, p, p, c, c, c, pad, pad, pad],
+
+     where p are extra prompt tokens left over from slicing the tensor, c are
+     completion tokens, and pad are pad tokens, this function builds a completion
+     mask that zeroes out the pad and p tokens, i.e. [0, 0, 0, 1, 1, 1, 0, 0, 0]
+     for this example.
+     """
+     batch_size, completion_len = completion_input_ids.shape
+     device = completion_input_ids.device
+
+     num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+     indices = torch.arange(completion_len, device=device).unsqueeze(0)
+     shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+     non_padding_mask = (completion_input_ids != pad_token_id)
+
+     final_mask = shift_mask & non_padding_mask
+
+     return final_mask
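A pure-Python sketch of the resulting mask for one row, taking `num_tokens_to_mask` directly rather than deriving it from `max_left_pad` (pad id 0 is illustrative):

```python
def completion_mask(row, num_tokens_to_mask, pad_id=0):
    # Zero out the first num_tokens_to_mask positions (leftover prompt
    # tokens from slicing) and every pad token; keep everything else.
    return [
        1 if i >= num_tokens_to_mask and tok != pad_id else 0
        for i, tok in enumerate(row)
    ]

# [p, p, p, c, c, c, pad, pad, pad] -> [0, 0, 0, 1, 1, 1, 0, 0, 0]
assert completion_mask([9, 9, 9, 2, 3, 4, 0, 0, 0], 3) == [0, 0, 0, 1, 1, 1, 0, 0, 0]
```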
+
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+     """
+     Moves all padding tokens in each sequence of a batch to the right.
+     """
+     mask = (tensor != pad_id)
+     # Must use stable=True, since argsort on a binary mask is otherwise unordered
+     sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+     packed_tensor = torch.gather(tensor, 1, sorted_indices)
+     return packed_tensor
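The stable-argsort trick can be seen in miniature with Python's `sorted`, which is also stable: sorting on an "is pad" key keeps the relative order of real tokens while pads sink to the right (pad id 0 is illustrative):

```python
def left_pack(row, pad_id=0):
    # sorted() is stable, so non-pad tokens keep their relative order
    # while every pad token moves to the end of the sequence.
    return sorted(row, key=lambda tok: tok == pad_id)

assert left_pack([0, 5, 0, 7, 8, 0]) == [5, 7, 8, 0, 0, 0]
```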
+
+ def align_logprobs_with_mask(
+     logprob_tensor: torch.Tensor,
+     attention_mask: torch.Tensor,
+     pad_value: float = 0.0
+ ) -> torch.Tensor:
+     """
+     Aligns a log probability tensor with a given attention mask.
+     """
+
+     device = logprob_tensor.device
+     batch_size, logprob_seq_len = logprob_tensor.shape
+     mask_seq_len = attention_mask.shape[1]
+
+     padded_logprobs = torch.full(
+         attention_mask.shape,
+         fill_value=pad_value,
+         dtype=logprob_tensor.dtype,
+         device=device
+     )
+
+     left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+     cols = torch.arange(logprob_seq_len, device=device)
+     dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+     # Create destination row indices
+     # Shape: [batch_size, logprob_seq_len]
+     row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+     # Filter out-of-bounds indices and perform the assignment.
+     # Create a mask to identify only the indices that are within the bounds
+     # of the target tensor's sequence length.
+     valid_mask = dest_indices < mask_seq_len
+
+     # Use this mask to select only the valid row indices, column indices,
+     # and the corresponding values from the logprob tensor.
+     # This flattens the selected elements into 1D tensors.
+     valid_rows = row_indices[valid_mask]
+     valid_cols = dest_indices[valid_mask]
+     valid_vals = logprob_tensor[valid_mask]
+
+     # Place the valid values into their correct positions in the padded tensor
+     # using a single advanced indexing operation.
+     padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+     return padded_logprobs
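Per row, the scatter above amounts to shifting the logprobs right by the row's left-pad count and dropping anything that runs past the mask length. A pure-Python single-row sketch (illustrative names, not the Unsloth API):

```python
def align_row(logprobs, attention_mask, pad_value=0.0):
    # Shift the logprobs right by the number of leading zeros in the
    # mask; destinations past the mask length are simply dropped.
    out = [pad_value] * len(attention_mask)
    left_pad = attention_mask.index(1) if 1 in attention_mask else 0
    for i, logprob in enumerate(logprobs):
        dest = left_pad + i
        if dest < len(out):
            out[dest] = logprob
    return out

assert align_row([0.1, 0.2], [0, 0, 1, 1, 1]) == [0.0, 0.0, 0.1, 0.2, 0.0]
```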
+
+ def autotune_batch_and_chunks(
+     total_input_rows,
+     seq_len,
+     hidden_size,
+     vocab_size,
+     dtype_bytes=16,
+     multiplier=None
+ ):
+     if multiplier is None:
+         final_m = max(4, seq_len // 4096)
+     else:
+         final_m = multiplier
+
+     if torch.cuda.is_available():
+         free_bytes, _ = torch.cuda.mem_get_info()
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     elif hasattr(torch, "xpu") and torch.xpu.is_available():
+         # For XPU: estimate free memory from total - reserved
+         total_mem = torch.xpu.get_device_properties(0).total_memory
+         reserved_mem = torch.xpu.memory_reserved()
+         free_bytes = total_mem - reserved_mem
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     else:
+         # Fallback: assume 8GB available
+         limit_gb = 8.0
+
+     bytes_to_gb = 1024**3
+
+     b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+     hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+     base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+     logits_gb = base_logits / final_m
+
+     total_mem_gb = hidden_gb + logits_gb
+
+     valid_mask = total_mem_gb <= limit_gb
+     valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+     if valid_indices.shape[0] == 0:
+         # This means your GPU will OOM
+         return 4, final_m
+
+     best_idx = valid_indices[0].item()
+     final_b = int(b_vals[best_idx].item())
+
+     return final_b, final_m
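The search above walks candidate batch sizes downward and takes the first one whose estimated footprint fits under the headroom limit. The estimate itself is simple arithmetic; a scalar sketch (`dtype_bytes=16` mirrors the default above, but the formula is an approximation, not a measured cost):

```python
def estimate_gb(b, total_rows, seq_len, hidden_size, vocab_size,
                dtype_bytes=16, multiplier=4):
    # Hidden states are held for all b rows; logits are chunked, so only
    # ~1/multiplier of them (scaled by b/total_rows) live at once.
    bytes_to_gb = 1024 ** 3
    hidden_gb = b * seq_len * hidden_size * dtype_bytes / bytes_to_gb
    logits_gb = (b / total_rows) * b * seq_len * vocab_size * dtype_bytes / multiplier / bytes_to_gb
    return hidden_gb + logits_gb

# Shrinking the batch shrinks the estimate, which is what the
# descending search over batch sizes relies on.
assert estimate_gb(4, 8, 1024, 4096, 32000) < estimate_gb(8, 8, 1024, 4096, 32000)
```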
+
+ def sanitize_logprob(logprob):
+     """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+     Filters NaN logprobs from vLLM outputs."""
+     value = logprob.logprob
+     if math.isnan(value):
+         logging.getLogger(__name__).warning(
+             f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+         )
+         return None
+     return value
+
+ @dataclass
+ class UnslothPPOConfig(PPOConfig):
+     """
+
+     Configuration class for the [`PPOTrainer`].
+
+     This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
+     please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
+     values in this class may differ from those in [`~transformers.TrainingArguments`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
+             Name of this experiment.
+         reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
+             Path to the reward model.
+         model_adapter_name (`str`, *optional*):
+             Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+         ref_adapter_name (`str`, *optional*):
+             Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+         num_ppo_epochs (`int`, *optional*, defaults to `4`):
+             Number of epochs to train.
+         whiten_rewards (`bool`, *optional*, defaults to `False`):
+             Whether to whiten the rewards.
+         kl_coef (`float`, *optional*, defaults to `0.05`):
+             KL coefficient.
+         kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
+             Which estimator for KL-Divergence to use from [Approximating KL
+             Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
+             estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
+             better estimator". Cannot be set to "k2", as it is used for logging purposes.
+         cliprange (`float`, *optional*, defaults to `0.2`):
+             Clip range.
+         vf_coef (`float`, *optional*, defaults to `0.1`):
+             Value function coefficient.
+         cliprange_value (`float`, *optional*, defaults to `0.2`):
+             Clip range for the value function.
+         gamma (`float`, *optional*, defaults to `1.0`):
+             Discount factor.
+         lam (`float`, *optional*, defaults to `0.95`):
+             Lambda value for GAE.
+         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+             This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+             improving generation speed. However, disabling this option allows training models that exceed the VRAM
+             capacity of a single GPU, albeit at the cost of slower generation.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     unsloth_logit_chunk_multiplier : Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Multiplier for chunked logit computations.'},
+     )
+     unsloth_grpo_mini_batch : Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+     )
+
+     def __init__(
+         self,
+         output_dir = None,
+         per_device_train_batch_size = 4,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         learning_rate = 5e-05,
+         lr_scheduler_type = 'linear',
+         lr_scheduler_kwargs = None,
+         warmup_steps = 0.1,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         optim_target_modules = None,
+         gradient_accumulation_steps = 2,
+         average_tokens_across_devices = True,
+         max_grad_norm = 1.0,
+         label_smoothing_factor = 0.0,
+         bf16 = False,
+         fp16 = False,
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         gradient_checkpointing = True,
+         gradient_checkpointing_kwargs = None,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         use_liger_kernel = False,
+         liger_kernel_config = None,
+         use_cache = False,
+         neftune_noise_alpha = None,
+         torch_empty_cache_steps = 250,
+         auto_find_batch_size = False,
+         logging_strategy = 'steps',
+         logging_steps = 1,
+         logging_first_step = False,
+         log_on_each_node = True,
+         logging_nan_inf_filter = False,
+         include_num_input_tokens_seen = False,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         disable_tqdm = None,
+         report_to = 'none',
+         run_name = None,
+         project = 'huggingface',
+         trackio_space_id = 'trackio',
+         eval_strategy = 'no',
+         eval_steps = None,
+         eval_delay = 0,
+         per_device_eval_batch_size = 4,
+         prediction_loss_only = False,
+         eval_on_start = False,
+         eval_do_concat_batches = True,
+         eval_use_gather_object = False,
+         eval_accumulation_steps = 2,
+         batch_eval_metrics = False,
+         save_only_model = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_on_each_node = False,
+         save_total_limit = None,
+         enable_jit_checkpoint = False,
+         push_to_hub = False,
+         hub_token = None,
+         hub_private_repo = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_always_push = False,
+         hub_revision = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         restore_callback_states_from_checkpoint = False,
+         full_determinism = False,
+         seed = 3407,
+         data_seed = 3407,
+         use_cpu = False,
+         accelerator_config = None,
+         parallelism_config = None,
+         dataloader_drop_last = False,
+         dataloader_num_workers = 0,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         dataloader_prefetch_factor = None,
+         remove_unused_columns = True,
+         label_names = None,
+         train_sampling_strategy = 'random',
+         length_column_name = 'length',
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         ddp_backend = None,
+         ddp_timeout = 1800,
+         fsdp = None,
+         fsdp_config = None,
+         deepspeed = None,
+         debug = '',
+         skip_memory_metrics = True,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         resume_from_checkpoint = None,
+         warmup_ratio = None,
+         logging_dir = None,
+         local_rank = -1,
+         dataset_num_proc = None,
+         num_mini_batches = 1,
+         total_episodes = None,
+         local_rollout_forward_batch_size = 64,
+         num_sample_generations = 10,
+         response_length = 53,
+         stop_token = None,
+         stop_token_id = None,
+         temperature = 0.7,
+         missing_eos_penalty = None,
+         sft_model_path = 'EleutherAI/pythia-160m',
+         world_size = None,
+         num_total_batches = None,
+         micro_batch_size = None,
+         local_batch_size = None,
+         batch_size = None,
+         local_mini_batch_size = None,
+         mini_batch_size = None,
+         exp_name = 'ppo_config',
+         reward_model_path = 'EleutherAI/pythia-160m',
+         model_adapter_name = None,
+         ref_adapter_name = None,
+         num_ppo_epochs = 4,
+         whiten_rewards = False,
+         kl_coef = 0.05,
+         kl_estimator = 'k1',
+         cliprange = 0.2,
+         vf_coef = 0.1,
+         cliprange_value = 0.2,
+         gamma = 1.0,
+         lam = 0.95,
+         ds3_gather_for_generation = True,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         unsloth_logit_chunk_multiplier = None,
+         unsloth_grpo_mini_batch = None,
+
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+         if num_train_epochs is None:
+             num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         import multiprocessing as _mp
+         if _mp.get_start_method() != 'fork':
+             dataset_num_proc = None
+         elif dataset_num_proc is None:
+             import psutil
+             dataset_num_proc = min(max((psutil.cpu_count() or 1) + 4, 2), 64)
+             memory_gb_left = psutil.virtual_memory().available / (1024**3)
+             if memory_gb_left <= 2: dataset_num_proc = 1
+             else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+         if temperature <= 0:
+             raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
+         elif temperature >= 10:
+             raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
+
+         super().__init__(
+             output_dir = output_dir,
+             per_device_train_batch_size = per_device_train_batch_size,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             learning_rate = learning_rate,
+             lr_scheduler_type = lr_scheduler_type,
+             lr_scheduler_kwargs = lr_scheduler_kwargs,
+             warmup_steps = warmup_steps,
+             optim = optim,
+             optim_args = optim_args,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             optim_target_modules = optim_target_modules,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             average_tokens_across_devices = average_tokens_across_devices,
+             max_grad_norm = max_grad_norm,
+             label_smoothing_factor = label_smoothing_factor,
+             bf16 = bf16,
+             fp16 = fp16,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             use_liger_kernel = use_liger_kernel,
+             liger_kernel_config = liger_kernel_config,
+             use_cache = use_cache,
+             neftune_noise_alpha = neftune_noise_alpha,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             auto_find_batch_size = auto_find_batch_size,
+             logging_strategy = logging_strategy,
+             logging_steps = logging_steps,
+             logging_first_step = logging_first_step,
+             log_on_each_node = log_on_each_node,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             disable_tqdm = disable_tqdm,
+             report_to = report_to,
+             run_name = run_name,
+             project = project,
+             trackio_space_id = trackio_space_id,
+             eval_strategy = eval_strategy,
+             eval_steps = eval_steps,
+             eval_delay = eval_delay,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             prediction_loss_only = prediction_loss_only,
+             eval_on_start = eval_on_start,
+             eval_do_concat_batches = eval_do_concat_batches,
+             eval_use_gather_object = eval_use_gather_object,
+             eval_accumulation_steps = eval_accumulation_steps,
+             batch_eval_metrics = batch_eval_metrics,
+             save_only_model = save_only_model,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_on_each_node = save_on_each_node,
+             save_total_limit = save_total_limit,
+             enable_jit_checkpoint = enable_jit_checkpoint,
+             push_to_hub = push_to_hub,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_always_push = hub_always_push,
+             hub_revision = hub_revision,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             full_determinism = full_determinism,
+             seed = seed,
+             data_seed = data_seed,
+             use_cpu = use_cpu,
+             accelerator_config = accelerator_config,
+             parallelism_config = parallelism_config,
+             dataloader_drop_last = dataloader_drop_last,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             train_sampling_strategy = train_sampling_strategy,
+             length_column_name = length_column_name,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             ddp_backend = ddp_backend,
+             ddp_timeout = ddp_timeout,
+             fsdp = fsdp,
+             fsdp_config = fsdp_config,
+             deepspeed = deepspeed,
+             debug = debug,
+             skip_memory_metrics = skip_memory_metrics,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             resume_from_checkpoint = resume_from_checkpoint,
+             warmup_ratio = warmup_ratio,
+             logging_dir = logging_dir,
+             local_rank = local_rank,
+             dataset_num_proc = dataset_num_proc,
+             num_mini_batches = num_mini_batches,
+             total_episodes = total_episodes,
+             local_rollout_forward_batch_size = local_rollout_forward_batch_size,
+             num_sample_generations = num_sample_generations,
+             response_length = response_length,
+             stop_token = stop_token,
+             stop_token_id = stop_token_id,
+             temperature = temperature,
+             missing_eos_penalty = missing_eos_penalty,
+             sft_model_path = sft_model_path,
+             world_size = world_size,
+             num_total_batches = num_total_batches,
+             micro_batch_size = micro_batch_size,
+             local_batch_size = local_batch_size,
+             batch_size = batch_size,
+             local_mini_batch_size = local_mini_batch_size,
+             mini_batch_size = mini_batch_size,
+             exp_name = exp_name,
+             reward_model_path = reward_model_path,
+             model_adapter_name = model_adapter_name,
+             ref_adapter_name = ref_adapter_name,
+             num_ppo_epochs = num_ppo_epochs,
+             whiten_rewards = whiten_rewards,
+             kl_coef = kl_coef,
+             kl_estimator = kl_estimator,
+             cliprange = cliprange,
+             vf_coef = vf_coef,
+             cliprange_value = cliprange_value,
+             gamma = gamma,
+             lam = lam,
+             ds3_gather_for_generation = ds3_gather_for_generation,
+             **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+         if unsloth_grpo_mini_batch is not None:
+             if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                 self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+             else:
+                 raise ValueError(
+                     "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                     "which is per_device_train_batch_size * gradient_accumulation_steps."
+                 )
+         self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+
+
+ pass
+
+ class _UnslothPPOTrainer(BaseTrainer):
+     """"""
+
+     _tag_names = ["trl", "ppo"]
+     _name = "PPO"
+     _paper = {
+         "title": "Fine-Tuning Language Models from Human Preferences",
+         "id": "1909.08593",
+         # docstyle-ignore
+         "citation": textwrap.dedent("""\
+             @article{mziegler2019fine-tuning,
+                 title = {{Fine-Tuning Language Models from Human Preferences}},
+                 author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
+                 year = 2019,
+                 eprint = {arXiv:1909.08593}
+             }"""),
+     }
+
+     def __init__(
+         self,
+         args: PPOConfig,
+         processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+         model: nn.Module,
+         ref_model: Optional[nn.Module],
+         reward_model: nn.Module,
+         train_dataset: Dataset,
+         value_model: nn.Module,
+         data_collator: Optional[DataCollatorWithPadding] = None,
+         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+         # less commonly used
+         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+         callbacks: Optional[list[TrainerCallback]] = None,
+         peft_config: Optional["PeftConfig"] = None,
+     ) -> None:
+         if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
+             warnings.warn(
+                 "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
+                 "it and want it to remain, please share your comments here: "
+                 "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
+                 "TRL_EXPERIMENTAL_SILENCE=1."
+             )
+         if ref_model is model:
+             raise ValueError(
+                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+                 "same as `model`, you must make a copy of it, or `None` if you use peft."
+             )
+
+         self.args = args
+         self.processing_class = processing_class
+         self.policy_model = model
+
+         # Define the collator if not provided
+         if data_collator is None:
+             data_collator = DataCollatorWithPadding(self.processing_class)
+
+         # Handle stop token settings: update policy model's generation_config to use provided stop token
+         if args.stop_token and args.stop_token_id:
+             raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
+         elif args.stop_token:
+             if args.stop_token == "eos":
+                 self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
+             else:
+                 raise ValueError(
+                     f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
+                 )
+         else:
+             self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int
+
+         # Check that the kl estimator is valid
+         if self.args.kl_estimator not in {"k1", "k3"}:
+             raise ValueError(
+                 "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
+                 "appears to be a strictly better estimator). See "
+                 "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
+             )
+
+         # peft support
+         if not is_peft_available() and peft_config is not None:
+             raise ImportError(
+                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+             )
+         elif is_peft_available() and peft_config is not None:
+             # if model is a peft model and we have a peft_config, we merge and unload it first
+             if isinstance(self.policy_model, PeftModel):
+                 self.policy_model = self.policy_model.merge_and_unload()
+
+             # get peft model with the given config
+             self.policy_model = get_peft_model(self.policy_model, peft_config)
+             if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
+                 peft_module_casting_to_bf16(self.policy_model)
+
+         self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
+         self.model_adapter_name = args.model_adapter_name
+         self.ref_adapter_name = args.ref_adapter_name
+
+         if ref_model:
+             self.ref_model = ref_model
+         elif self.is_peft_model:
+             self.ref_model = None
+         else:
+             self.ref_model = create_reference_model(self.policy_model)
+
+         self.reward_model = reward_model
+         self.train_dataset = train_dataset
+         self.train_dataset_len = len(train_dataset)
+         self.value_model = value_model
+         self.data_collator = data_collator
+         self.eval_dataset = eval_dataset
+         self.optimizer, self.lr_scheduler = optimizers
+         self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47
+
+         #########
+         # calculate various batch sizes
+         #########
+         if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
+             args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
+         accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+         self.accelerator = accelerator
+         args.world_size = accelerator.num_processes
+         args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
+         args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
+         args.batch_size = int(args.local_batch_size * args.world_size)
+         args.mini_batch_size = exact_div(
+             args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
+         )
+         args.local_mini_batch_size = exact_div(
+             args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
+         )
+         if args.whiten_rewards:
+             assert args.local_mini_batch_size >= 8, (
+                 f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
+             )
+         # `per_rank_rollout_batch_size` is our `args.local_batch_size`
+         # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
+         args.num_total_batches = math.ceil(
+             args.total_episodes / args.batch_size
+         )  # we may train for more than `total_episodes`
+         time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
+         time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
+         args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
+         self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
+         if args.num_sample_generations > 0:
+             self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
+         self.local_dataloader_batch_size = args.local_batch_size
+
+         #########
+         # setup model, optimizer, and others
+         #########
+         for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
+             if module is not None:
+                 disable_dropout_in_model(module)
+         self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
+         self.model.config = self.policy_model.config  # needed for pushing to hub
+         self.create_optimizer_and_scheduler(
+             num_training_steps=args.num_total_batches
+         )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
+
+         #########
+         # trainer specifics
+         #########
+         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+         self.callback_handler = CallbackHandler(
+             self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+         )
+         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+         self.control = TrainerControl()
+         self.state = OnlineTrainerState(
+             is_local_process_zero=self.is_local_process_zero(),
+             is_world_process_zero=self.is_world_process_zero(),
+             stateful_callbacks=[
+                 cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+             ],
+         )
+         self.current_flos = 0
+         self.hp_search_backend = None
+         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+         # Create distant repo and output directory if needed
+         self.hub_model_id = None
+         if self.args.push_to_hub:
+             self.init_hf_repo()
+         if self.args.should_save:
+             os.makedirs(self.args.output_dir, exist_ok=True)
+
+         # Add tags for models that have been loaded with the correct transformers version
+         if hasattr(self.model, "add_model_tags"):
+             self.model.add_model_tags(self._tag_names)
+
+ #########
907
+ # setup dataloader
908
+ #########
909
+ self.dataloader = DataLoader(
910
+ self.train_dataset,
911
+ batch_size=self.local_dataloader_batch_size,
912
+ shuffle=True,
913
+ collate_fn=self.data_collator,
914
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
915
+ )
916
+ # sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
917
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
918
+ torch.manual_seed(args.seed)
919
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
920
+ torch.manual_seed(self.local_seed) # reset the local seed again
921
+
922
+ self.eval_dataloader = DataLoader(
923
+ self.eval_dataset,
924
+ batch_size=args.per_device_eval_batch_size,
925
+ collate_fn=self.data_collator,
926
+ drop_last=True,
927
+ ) # no need to shuffle eval dataset
928
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
929
+
930
+ if self.is_deepspeed_enabled:
931
+ self.reward_model = prepare_deepspeed(
932
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
933
+ )
934
+
935
+ if self.ref_model is None:
936
+ if not self.is_peft_model:
937
+ raise ValueError("No reference model and model is not a Peft model.")
938
+ else:
939
+ self.ref_model = prepare_deepspeed(
940
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
941
+ )
942
+ else:
943
+ if self.ref_model is None:
944
+ if not self.is_peft_model:
945
+ raise ValueError("No reference model and model is not a Peft model.")
946
+ else:
947
+ self.ref_model = self.ref_model.to(self.accelerator.device)
948
+ self.reward_model = self.reward_model.to(self.accelerator.device)
949
+
950
+     def get_train_dataloader(self) -> DataLoader:
+         return self.dataloader
+
+     def get_eval_dataloader(self) -> DataLoader:
+         return self.eval_dataloader
+
+     @contextmanager
+     def null_ref_context(self):
+         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+         with (
+             self.accelerator.unwrap_model(self.model.policy).disable_adapter()
+             if self.is_peft_model and not self.ref_adapter_name
+             else nullcontext()
+         ):
+             if self.ref_adapter_name:
+                 self.model.policy.set_adapter(self.ref_adapter_name)
+             yield
+             if self.ref_adapter_name:
+                 self.model.policy.set_adapter(self.model_adapter_name or "default")
+
+     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+         backup_model = self.model
+         self.model = self.model.policy  # save only the policy
+
+         if self.is_deepspeed_enabled:
+             backup_deepspeed = self.deepspeed
+             self.deepspeed = self.model
+
+         super().save_model(output_dir, _internal_call)
+
+         self.model = backup_model
+
+         if self.is_deepspeed_enabled:
+             self.deepspeed = backup_deepspeed
+
+     def train(self):
+         args = self.args
+         accelerator = self.accelerator
+         optimizer = self.optimizer
+         model = self.model
+         ref_policy = self.ref_model
+         reward_model = self.reward_model
+         processing_class = self.processing_class
+         dataloader = self.dataloader
+         device = accelerator.device
+
+         def repeat_generator():
+             while True:
+                 yield from dataloader
+
+         iter_dataloader = iter(repeat_generator())
+         generation_config = GenerationConfig(
+             max_new_tokens=args.response_length,
+             temperature=(args.temperature + 1e-7),
+             top_k=0.0,
+             top_p=1.0,
+             do_sample=True,
+         )
+
+         accelerator.print("===training policy===")
+         start_time = time.time()
+         stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
+         approxkl_stats = torch.zeros(stats_shape, device=device)
+         pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
+         pg_loss_stats = torch.zeros(stats_shape, device=device)
+         vf_loss_stats = torch.zeros(stats_shape, device=device)
+         vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
+         entropy_stats = torch.zeros(stats_shape, device=device)
+         ratio_stats = torch.zeros(stats_shape, device=device)
+         model.train()
+
+         # trainer state initialization
+         self.state.global_step = 0
+         self.state.episode = 0
+         self.state.max_steps = args.num_total_batches
+         self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
+         # Compute absolute values for logging, eval, and save if given as ratio
+         if args.logging_steps is not None:
+             if args.logging_steps < 1:
+                 self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
+             else:
+                 self.state.logging_steps = args.logging_steps
+         if args.eval_steps is not None:
+             if args.eval_steps < 1:
+                 self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
+             else:
+                 self.state.eval_steps = args.eval_steps
+         if args.save_steps is not None:
+             if args.save_steps < 1:
+                 self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
+             else:
+                 self.state.save_steps = args.save_steps
+         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+         # backward compatibility
+         if self.is_deepspeed_enabled:
+             self.deepspeed = self.model
+             self.model_wrapped = self.model
+
+         for update in range(1, args.num_total_batches + 1):
+             self.state.episode += 1 * args.batch_size
+             data = next(iter_dataloader)
+             with torch.no_grad():
+                 queries = data["input_ids"].to(device)
+                 context_length = queries.shape[1]
+                 responses = []
+                 postprocessed_responses = []
+                 logprobs = []
+                 ref_logprobs = []
+                 scores = []
+                 sequence_lengths = []
+                 values = []
+                 with unwrap_model_for_generation(
+                     self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                 ) as unwrapped_model:
+                     query_responses, logitss = batch_generation(
+                         unwrapped_model.policy,
+                         queries,
+                         args.local_rollout_forward_batch_size,
+                         processing_class.pad_token_id,
+                         generation_config,
+                     )
+
+                 for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
+                     query = queries[i : i + args.local_rollout_forward_batch_size]
+                     query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
+                     response = query_response[:, context_length:]
+                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
+                     logprob = selective_log_softmax(logits, response)
+                     del logits
+                     empty_cache()
+
+                     if ref_policy is None:
+                         with self.null_ref_context():
+                             ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
+                     else:
+                         ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
+                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
+                     ref_logits /= args.temperature + 1e-7
+                     ref_logprob = selective_log_softmax(ref_logits, response)
+                     del ref_output, ref_logits
+                     empty_cache()
+
+                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
+                     postprocessed_response = response
+                     if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
+                         postprocessed_response = truncate_response(
+                             self.stop_token_id, processing_class.pad_token_id, response
+                         )
+
+                     # Response Processing 2. run reward model on the truncated responses
+                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+                     sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
+                     unwrapped_value_model = accelerator.unwrap_model(model).value_model
+                     full_value, _, _ = get_reward(
+                         unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
+                     )
+                     value = full_value[:, context_length - 1 : -1].squeeze(-1)
+                     _, score, _ = get_reward(
+                         reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                     )
+
+                     responses.append(response)
+                     postprocessed_responses.append(postprocessed_response)
+                     logprobs.append(logprob)
+                     ref_logprobs.append(ref_logprob)
+                     sequence_lengths.append(sequence_length)
+                     scores.append(score)
+                     values.append(value)
+                 responses = torch.cat(responses, 0)
+                 postprocessed_responses = torch.cat(postprocessed_responses, 0)
+                 logprobs = torch.cat(logprobs, 0)
+                 ref_logprobs = torch.cat(ref_logprobs, 0)
+                 sequence_lengths = torch.cat(sequence_lengths, 0)
+                 scores = torch.cat(scores, 0)
+                 values = torch.cat(values, 0)
+                 del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
+                 empty_cache()
+                 gc.collect()
+
+                 # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
+                 # Completions not passing that filter will receive a lower score.
+                 contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
+                 if self.args.missing_eos_penalty is not None:
+                     scores[~contain_eos_token] -= self.args.missing_eos_penalty
+                 # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
+
+                 # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
+                 response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
+                 padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
+                 logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
+                 ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
+                 sequence_lengths_p1 = sequence_lengths + 1
+                 padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
+                 values = torch.masked_fill(values, padding_mask_p1, 0)
+
+                 # 4. compute rewards
+                 # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
+                 logr = ref_logprobs - logprobs
+                 kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
+                 non_score_reward = -args.kl_coef * kl
+                 rewards = non_score_reward.clone()
+                 actual_start = torch.arange(rewards.size(0), device=rewards.device)
+                 actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
+                 rewards[[actual_start, actual_end]] += scores
+
+                 # 5. whiten rewards
+                 if args.whiten_rewards:
+                     rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
+                     rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
+
+                 # 6. compute advantages and returns
+                 lastgaelam = 0
+                 advantages_reversed = []
+                 gen_length = responses.shape[1]
+                 for t in reversed(range(gen_length)):
+                     nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
+                     delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
+                     lastgaelam = delta + args.gamma * args.lam * lastgaelam
+                     advantages_reversed.append(lastgaelam)
+                 advantages = torch.stack(advantages_reversed[::-1], axis=1)
+                 returns = advantages + values
+                 advantages = masked_whiten(advantages, ~padding_mask)
+                 advantages = torch.masked_fill(advantages, padding_mask, 0)
+                 empty_cache()
+
+             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
+             for ppo_epoch_idx in range(args.num_ppo_epochs):
+                 b_inds = np.random.permutation(args.local_batch_size)
+                 minibatch_idx = 0
+                 for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
+                     mini_batch_end = mini_batch_start + args.local_mini_batch_size
+                     mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
+                     gradient_accumulation_idx = 0
+                     for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
+                         with accelerator.accumulate(model):
+                             micro_batch_end = micro_batch_start + args.per_device_train_batch_size
+                             micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
+                             mb_advantage = advantages[micro_batch_inds]
+                             mb_responses = responses[micro_batch_inds]
+                             mb_query_responses = query_responses[micro_batch_inds]
+                             mb_logprobs = logprobs[micro_batch_inds]
+                             mb_return = returns[micro_batch_inds]
+                             mb_values = values[micro_batch_inds]
+
+                             output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
+                             logits = output.logits[:, context_length - 1 : -1]
+                             logits /= args.temperature + 1e-7
+                             new_logprobs = selective_log_softmax(logits, mb_responses)
+                             new_logprobs = torch.masked_fill(
+                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
+                             )
+                             vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
+                             vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
+                             vpredclipped = torch.clamp(
+                                 vpred,
+                                 mb_values - args.cliprange_value,
+                                 mb_values + args.cliprange_value,
+                             )
+                             vf_losses1 = torch.square(vpred - mb_return)
+                             vf_losses2 = torch.square(vpredclipped - mb_return)
+                             vf_loss_max = torch.max(vf_losses1, vf_losses2)
+                             vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
+                             vf_clipfrac = masked_mean(
+                                 (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
+                             )
+                             logprobs_diff = new_logprobs - mb_logprobs
+                             ratio = torch.exp(logprobs_diff)
+                             pg_losses = -mb_advantage * ratio
+                             pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
+                             pg_loss_max = torch.max(pg_losses, pg_losses2)
+                             pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
+                             loss = pg_loss + args.vf_coef * vf_loss
+                             accelerator.backward(loss)
+                             optimizer.step()
+                             optimizer.zero_grad()
+                             with torch.no_grad():
+                                 pg_clipfrac = masked_mean(
+                                     (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
+                                 )
+                                 prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype)
+                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
+                                 approxkl = 0.5 * (logprobs_diff**2).mean()
+                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
+                                 pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                                     pg_clipfrac
+                                 )
+                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
+                                 vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
+                                 vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+                                     vf_clipfrac
+                                 )
+                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
+                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
+                         gradient_accumulation_idx += 1
+                     minibatch_idx += 1
+                     # del everything and empty cache
+                     # fmt: off
+                     del (
+                         output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
+                         vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
+                         pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
+                         mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
+                     )
+                     # fmt: on
+                     empty_cache()
+             with torch.no_grad():
+                 mean_kl = kl.sum(1).mean()
+                 mean_entropy = (-logprobs).sum(1).mean()
+                 mean_non_score_reward = non_score_reward.sum(1).mean()
+                 rlhf_reward = mean_non_score_reward + scores.mean()
+                 eps = int(self.state.episode / (time.time() - start_time))
+                 metrics = {}
+                 metrics["eps"] = eps
+                 metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+                 metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+                 metrics["objective/non_score_reward"] = (
+                     self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+                 )
+                 metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+                 metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+                 metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+                 metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+                 metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+                 metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
+                 metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+                 metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+                 metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+                 metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
+                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
+                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
+                 metrics["episode"] = self.state.episode
+                 self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
+                 self.state.global_step += 1
+                 self.log(metrics)
+
+             self.lr_scheduler.step()
+             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+             if self.control.should_save:
+                 self._save_checkpoint(model, trial=None)
+                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
+             empty_cache()
+             gc.collect()
+
+             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
+                 self.generate_completions(sampling=True)
+                 empty_cache()
+             del (
+                 query_responses,
+                 responses,
+                 postprocessed_responses,
+                 logprobs,
+                 ref_logprobs,
+                 values,
+                 sequence_lengths,
+                 contain_eos_token,
+                 sequence_lengths_p1,
+                 response_idxs,
+                 padding_mask,
+                 padding_mask_p1,
+                 rewards,
+                 actual_start,
+                 actual_end,
+                 advantages,
+                 returns,
+             )
+             empty_cache()
+
+         # HF trainer specifics
+         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+         if self.control.should_save:
+             self._save_checkpoint(model, trial=None)
+             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+     def generate_completions(self, sampling: bool = False):
+         args = self.args
+         processing_class = self.processing_class
+         generation_config = GenerationConfig(
+             max_new_tokens=self.args.response_length,
+             temperature=(0.01 + 1e-7),
+             top_k=0.0,
+             top_p=1.0,
+             do_sample=True,
+         )
+
+         table = defaultdict(list)
+         with unwrap_model_for_generation(
+             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+         ) as unwrapped_model:
+             for batch in self.eval_dataloader:
+                 query = batch["input_ids"]
+                 with torch.no_grad():
+                     context_length = query.shape[1]
+                     query_response, _ = batch_generation(
+                         unwrapped_model.policy,
+                         query,
+                         query.shape[0],
+                         processing_class.pad_token_id,
+                         generation_config,
+                     )
+                     response = query_response[:, context_length:]
+                     postprocessed_response = response
+                     if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
+                         postprocessed_response = truncate_response(
+                             self.stop_token_id, processing_class.pad_token_id, response
+                         )
+                     table["query"].extend(
+                         gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
+                     )
+                     table["model response"].extend(
+                         gather_object(processing_class.batch_decode(postprocessed_response))
+                     )
+
+                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+                     _, score, _ = get_reward(
+                         self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+                     )
+                     table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
+
+                 if sampling:
+                     break
+         df = pd.DataFrame(table)
+
+         if self.accelerator.is_main_process:
+             if is_rich_available():
+                 print_rich_table(df.iloc[0 : 0 + 5])
+             if "wandb" in args.report_to:
+                 import wandb
+
+                 if wandb.run is not None:
+                     wandb.log({"completions": wandb.Table(dataframe=df)})
+
+             if "comet_ml" in args.report_to:
+                 log_table_to_comet_experiment(
+                     name="completions.csv",
+                     table=df,
+                 )
+
+     # Ensure the model card is saved along with the checkpoint
+     def _save_checkpoint(self, model, trial):
+         if self.args.hub_model_id is None:
+             model_name = Path(self.args.output_dir).name
+         else:
+             model_name = self.args.hub_model_id.split("/")[-1]
+         self.create_model_card(model_name=model_name)
+         super()._save_checkpoint(model, trial)
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
+     """
+     Trainer for Proximal Policy Optimization (PPO).
+
+     For details on PPO, see the paper: [Proximal Policy Optimization
+     Algorithms](https://huggingface.co/papers/1707.06347).
+
+     Args:
+         args ([`PPOConfig`]):
+             Training arguments.
+         processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
+             Class to process the data.
+         model (`torch.nn.Module`):
+             Model to be trained. This is the policy model.
+         ref_model (`torch.nn.Module`, *optional*):
+             Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
+         reward_model (`torch.nn.Module`):
+             Reward model used to compute the rewards.
+         train_dataset ([`~datasets.Dataset`]):
+             Dataset for training.
+         value_model (`torch.nn.Module`):
+             Value model used to predict the value of a state.
+         data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
+             Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
+             using the `processing_class`.
+         eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
+             Dataset for evaluation.
+         optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
+             Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
+             optimizer and the learning rate scheduler are created using the
+             [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
+         callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
+             Callbacks to use during training.
+         peft_config ([`~peft.PeftConfig`], *optional*):
+             PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
+             will be wrapped with the specified PEFT adapter.
+
+     """
+     def __init__(
+         self,
+         args,
+         processing_class,
+         model,
+         ref_model,
+         reward_model,
+         train_dataset,
+         value_model,
+         data_collator = None,
+         eval_dataset = None,
+         callbacks = None,
+         peft_config = None,
+         **kwargs
+     ):
+         if args is None: args = UnslothPPOConfig()
+         use_bf16 = getattr(args, 'bf16', False)
+         if type(use_bf16) is not bool: use_bf16 = False
+         use_fp16 = getattr(args, 'fp16', False)
+         if type(use_fp16) is not bool: use_fp16 = False
+         force_float32 = False
+         full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+         if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+             print('Unsloth: Switching to float32 training since model cannot work with float16')
+             force_float32 = True
+         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+         dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+         if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+         from unsloth_zoo.utils import _get_dtype
+         dtype = _get_dtype(dtype)
+         float16 = dtype == torch.float16
+         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+         if force_float32:
+             # Forced float32 training
+             args.fp16 = False
+             args.bf16 = False
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+             # args.mixed_precision is a new argument which needs to be set now
+         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+             # Mixed precision training
+             args.fp16 = float16
+             args.bf16 = not float16
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+             # args.mixed_precision is a new argument which needs to be set now
+         elif mixed_precision_dtype == 'bfloat16':
+             # Both False since bfloat16 full finetuning doesn't do any autocasting.
+             args.fp16 = False
+             args.bf16 = False
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+             # args.mixed_precision is a new argument which needs to be set now
+
+         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+             args.eval_strategy = 'steps'
+             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+         if ga_steps is not None and ga_steps > 1:
+             from transformers import __version__ as transformers_version
+             if Version(transformers_version) <= Version('4.45.2'):
+                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+         if getattr(args, 'eval_strategy', 'no') != 'no':
+             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+         if type(fp16_full_eval) is not bool: fp16_full_eval = False
+         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+         if type(bf16_full_eval) is not bool: bf16_full_eval = False
+         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+         if force_float32:
+             args.bf16_full_eval = False
+             args.fp16_full_eval = False
+         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+             args.bf16_full_eval = True
+             args.fp16_full_eval = False
+         elif not bf16_full_eval and not fp16_full_eval:
+             args.bf16_full_eval = args.bf16
+             args.fp16_full_eval = args.fp16
+         _output_logits = False
+         if locals().get('compute_metrics', None) is not None: _output_logits = True
+         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+         if _output_logits:
+             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+         if model is not None:
+             _warnings_issued = getattr(model, 'warnings_issued', None)
+             if _warnings_issued is None:
+                 model.warnings_issued = {}
+             elif not isinstance(_warnings_issued, dict):
+                 try:
+                     model.warnings_issued = dict(_warnings_issued)
+                 except Exception:
+                     model.warnings_issued = {}
+         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+             pass
+         else:
+             model_max_seq_length = getattr(model, 'max_seq_length', None)
+             args_max_seq_length = getattr(args, 'max_seq_length', None)
+             if args_max_seq_length is None and model_max_seq_length is not None:
+                 max_seq_length = model.max_seq_length
+                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+             elif args_max_seq_length is not None and model_max_seq_length is not None:
+                 if args_max_seq_length > model_max_seq_length:
+                     print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+                           'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+                     args.max_seq_length = model_max_seq_length
+         if model is not None and hasattr(model, 'for_training'):
+             model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+         if 'processing_class' in locals():
+             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                 data_collator = TransformersDataCollatorForLanguageModeling(
+                     __tokenizer,
+                     mlm = False,
+                     mlm_probability = 0.0,
+                     pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                 )
+             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                 data_collator = DataCollatorForSeq2Seq(
+                     __tokenizer,
+                     pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                 )
+         else:
+             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                 if isinstance(data_collator, DataCollatorForSeq2Seq):
+                     data_collator = DataCollatorForSeq2Seq(
+                         __tokenizer.tokenizer,
+                         pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                     )
+                 else:
+                     data_collator = TransformersDataCollatorForLanguageModeling(
+                         __tokenizer.tokenizer,
+                         mlm = False,
+                         mlm_probability = 0.0,
+                         pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                     )
+         other_metrics = []
+
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('ppo_trainer', other_metrics)
+
+         # [TODO] Fix up DataParallel multiplying batch sizes
+         # [TODO] DDP works, but DP seems to not work? [TODO]
+         if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+             if getattr(args, "_n_gpu", 1) != 1:
+                 args._n_gpu = 1
+         if "model" in locals() and hasattr(model, "for_training"):
+             model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+         super().__init__(
1597
+ args = args,
1598
+ processing_class = processing_class,
1599
+ model = model,
1600
+ ref_model = ref_model,
1601
+ reward_model = reward_model,
1602
+ train_dataset = train_dataset,
1603
+ value_model = value_model,
1604
+ data_collator = data_collator,
1605
+ eval_dataset = eval_dataset,
1606
+ callbacks = callbacks,
1607
+ peft_config = peft_config,**kwargs)
1608
+ if "model" in locals() and hasattr(model, "for_inference"):
1609
+ model.for_inference()
1610
+ if hasattr(self, 'neftune_hook_handle'):
1611
+ self.neftune_hook_handle.remove()
1612
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1613
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1614
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1615
+ pass
1616
+ if hasattr(self, 'accelerator'):
1617
+ scaler = self.accelerator.scaler
1618
+ current_model = model
1619
+ while hasattr(current_model, 'model'):
1620
+ current_model.accelerator_scaler = scaler
1621
+ current_model = current_model.model
1622
+ current_model.accelerator_scaler = scaler
1623
+ pass
1624
+ if hasattr(self, 'train'):
1625
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1626
+ pass
1627
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1628
+ _vllm_tok = self.llm.get_tokenizer()
1629
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1630
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1631
+ _vllm_tok.chat_template = _pc.chat_template
1632
+ pass
1633
+
1634
+ pass
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,1109 @@
 
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.prm_trainer import (BaseImageProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, nn, os, textwrap, torch, warnings, BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PartialState, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, compute_accuracy, disable_dropout_in_model, features, nn, os, torch, warnings, PreTrainedModel, os, torch)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
+try:
+    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+except:
+    def reset_unsloth_gradient_checkpointing_buffers(): pass
+def prepare_for_training_mode(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        # Enable training mode
+        _was_training = None
+        # Get gradient checkpointing setting from training arguments
+        use_gc = getattr(self.args, 'gradient_checkpointing', True)
+        if hasattr(self, 'model') and hasattr(self.model, "training"):
+            _was_training = self.model.training
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+            self.model.for_training(use_gradient_checkpointing=use_gc)
+        output = f(self, *args, **kwargs)
+        # Restore previous mode when possible
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+            if _was_training is False:
+                self.model.for_inference()
+            elif _was_training is True and hasattr(self.model, "for_training"):
+                self.model.for_training(use_gradient_checkpointing=use_gc)
+        # Reset gradient checkpointing buffers to free memory while staying ready for next run
+        try:
+            reset_unsloth_gradient_checkpointing_buffers()
+        except:
+            pass
+        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+        try:
+            import wandb
+            wandb.finish()
+        except:
+            pass
+        return output
+    return wrapper
+pass
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_hidden_states_selective_log_softmax(
+    hidden_states: torch.Tensor,
+    lm_head: torch.Tensor,
+    index: torch.Tensor,
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
+
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+
+    padding_mask = (prompt_section == pad_token_id)
+
+    pad_token_counts = padding_mask.sum(dim=1)
+
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
+
+    Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
+    and pad are pad tokens, this function would make a completion mask that would 0 out the pad
+    and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    # Must do stable=True since binary mark is unordered
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
+
+def align_logprobs_with_mask(
+    logprob_tensor: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pad_value: float = 0.0
+) -> torch.Tensor:
+    """
+    Aligns a log probability tensor with a given attention mask.
+    """
+
+    device = logprob_tensor.device
+    batch_size, logprob_seq_len = logprob_tensor.shape
+    mask_seq_len = attention_mask.shape[1]
+
+    padded_logprobs = torch.full(
+        attention_mask.shape,
+        fill_value=pad_value,
+        dtype=logprob_tensor.dtype,
+        device=device
+    )
+
+    left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+    cols = torch.arange(logprob_seq_len, device=device)
+    dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+    # Create destination row indices
+    # Shape: [batch_size, logprob_seq_len]
+    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+    # --- 4. Filter out-of-bounds indices and perform assignment ---
+    # Create a mask to identify only the indices that are within the bounds
+    # of the target tensor's sequence length.
+    valid_mask = dest_indices < mask_seq_len
+
+    # Use this mask to select only the valid row indices, column indices,
+    # and the corresponding values from the logprob tensor.
+    # This flattens the selected elements into 1D tensors.
+    valid_rows = row_indices[valid_mask]
+    valid_cols = dest_indices[valid_mask]
+    valid_vals = logprob_tensor[valid_mask]
+
+    # Place the valid values into their correct positions in the padded tensor
+    # using a single, efficient advanced indexing operation.
+    padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+    return padded_logprobs
+
+def autotune_batch_and_chunks(
+    total_input_rows,
+    seq_len,
+    hidden_size,
+    vocab_size,
+    dtype_bytes=16,
+    multiplier=None
+):
+    if multiplier is None:
+        final_m = max(4, seq_len // 4096)
+    else:
+        final_m = multiplier
+
+    if torch.cuda.is_available():
+        free_bytes, _ = torch.cuda.mem_get_info()
+        limit_gb = (free_bytes / (1024**3))*.80
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        # For XPU: estimate free memory from total - reserved
+        total_mem = torch.xpu.get_device_properties(0).total_memory
+        reserved_mem = torch.xpu.memory_reserved()
+        free_bytes = total_mem - reserved_mem
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    else:
+        # Fallback: assume 8GB available
+        limit_gb = 8.0
+
+    bytes_to_gb = 1024**3
+
+    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+    base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+    logits_gb = base_logits / final_m
+
+    total_mem_gb = hidden_gb + logits_gb
+
+    valid_mask = total_mem_gb <= limit_gb
+    valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+    if valid_indices.shape[0] == 0:
+        # This means your GPU will OOM
+        return 4, final_m
+
+    best_idx = valid_indices[0].item()
+    final_b = int(b_vals[best_idx].item())
+
+    return final_b, final_m
+
+def sanitize_logprob(logprob):
+    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+    Filters NaN logprobs from vLLM outputs."""
+    value = logprob.logprob
+    if math.isnan(value):
+        logging.getLogger(__name__).warning(
+            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+        )
+        return None
+    return value
+@dataclass
+class UnslothPRMConfig(PRMConfig):
+    """
+
+    Configuration class for the [`PRMTrainer`].
+
+    This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
+    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+    differ from those in [`~transformers.TrainingArguments`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the sequences (prompt + completion) used for truncation.
+        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+            Maximum length of the prompt used for truncation.
+        max_completion_length (`int`, *optional*):
+            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+        step_separator (`str`, *optional*, defaults to `"\n"`):
+            Separator used to separate each step of the reasoning process.
+        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
+            Whether to train only on the last step.
+        dataset_num_proc (`int`, *optional*):
+            Number of processes to use for processing the dataset.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    unsloth_logit_chunk_multiplier : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Multiplier for chunked logit computations.'},
+    )
+    unsloth_grpo_mini_batch : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+    )
+    max_seq_length : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Maximum sequence length to truncate to.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        per_device_train_batch_size = 4,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        learning_rate = 5e-05,
+        lr_scheduler_type = 'linear',
+        lr_scheduler_kwargs = None,
+        warmup_steps = 0.1,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        optim_target_modules = None,
+        gradient_accumulation_steps = 2,
+        average_tokens_across_devices = True,
+        max_grad_norm = 1.0,
+        label_smoothing_factor = 0.0,
+        bf16 = False,
+        fp16 = False,
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = None,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        use_liger_kernel = False,
+        liger_kernel_config = None,
+        use_cache = False,
+        neftune_noise_alpha = None,
+        torch_empty_cache_steps = 250,
+        auto_find_batch_size = False,
+        logging_strategy = 'steps',
+        logging_steps = 1,
+        logging_first_step = False,
+        log_on_each_node = True,
+        logging_nan_inf_filter = False,
+        include_num_input_tokens_seen = False,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        disable_tqdm = None,
+        report_to = 'none',
+        run_name = None,
+        project = 'huggingface',
+        trackio_space_id = 'trackio',
+        eval_strategy = 'no',
+        eval_steps = None,
+        eval_delay = 0,
+        per_device_eval_batch_size = 4,
+        prediction_loss_only = False,
+        eval_on_start = False,
+        eval_do_concat_batches = True,
+        eval_use_gather_object = False,
+        eval_accumulation_steps = 2,
+        batch_eval_metrics = False,
+        save_only_model = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_on_each_node = False,
+        save_total_limit = None,
+        enable_jit_checkpoint = False,
+        push_to_hub = False,
+        hub_token = None,
+        hub_private_repo = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_always_push = False,
+        hub_revision = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        restore_callback_states_from_checkpoint = False,
+        full_determinism = False,
+        seed = 3407,
+        data_seed = 3407,
+        use_cpu = False,
+        accelerator_config = None,
+        parallelism_config = None,
+        dataloader_drop_last = False,
+        dataloader_num_workers = 0,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        dataloader_prefetch_factor = None,
+        remove_unused_columns = True,
+        label_names = None,
+        train_sampling_strategy = 'random',
+        length_column_name = 'length',
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        ddp_backend = None,
+        ddp_timeout = 1800,
+        fsdp = None,
+        fsdp_config = None,
+        deepspeed = None,
+        debug = '',
+        skip_memory_metrics = True,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        resume_from_checkpoint = None,
+        warmup_ratio = None,
+        logging_dir = None,
+        local_rank = -1,
+        max_length = 1024,
+        max_prompt_length = 512,
+        max_completion_length = None,
+        disable_dropout = True,
+        step_separator = '\n',
+        train_on_last_step_only = False,
+        dataset_num_proc = None,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        unsloth_logit_chunk_multiplier = None,
+        unsloth_grpo_mini_batch = None,
+        max_seq_length = None,
+        **kwargs,
+    ):
+        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+        if num_train_epochs is None:
+            num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        import multiprocessing as _mp
+        if _mp.get_start_method() != 'fork':
+            dataset_num_proc = None
+        elif dataset_num_proc is None:
+            import psutil
+            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+            memory_gb_left = psutil.virtual_memory().available / (1024**3)
+            if memory_gb_left <= 2: dataset_num_proc = 1
+            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+
+        super().__init__(
+            output_dir = output_dir,
+            per_device_train_batch_size = per_device_train_batch_size,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            learning_rate = learning_rate,
+            lr_scheduler_type = lr_scheduler_type,
+            lr_scheduler_kwargs = lr_scheduler_kwargs,
+            warmup_steps = warmup_steps,
+            optim = optim,
+            optim_args = optim_args,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            optim_target_modules = optim_target_modules,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            average_tokens_across_devices = average_tokens_across_devices,
+            max_grad_norm = max_grad_norm,
+            label_smoothing_factor = label_smoothing_factor,
+            bf16 = bf16,
+            fp16 = fp16,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            use_liger_kernel = use_liger_kernel,
+            liger_kernel_config = liger_kernel_config,
+            use_cache = use_cache,
+            neftune_noise_alpha = neftune_noise_alpha,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            auto_find_batch_size = auto_find_batch_size,
+            logging_strategy = logging_strategy,
+            logging_steps = logging_steps,
+            logging_first_step = logging_first_step,
+            log_on_each_node = log_on_each_node,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            disable_tqdm = disable_tqdm,
+            report_to = report_to,
+            run_name = run_name,
+            project = project,
+            trackio_space_id = trackio_space_id,
+            eval_strategy = eval_strategy,
+            eval_steps = eval_steps,
+            eval_delay = eval_delay,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            prediction_loss_only = prediction_loss_only,
+            eval_on_start = eval_on_start,
+            eval_do_concat_batches = eval_do_concat_batches,
+            eval_use_gather_object = eval_use_gather_object,
+            eval_accumulation_steps = eval_accumulation_steps,
+            batch_eval_metrics = batch_eval_metrics,
+            save_only_model = save_only_model,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_on_each_node = save_on_each_node,
+            save_total_limit = save_total_limit,
+            enable_jit_checkpoint = enable_jit_checkpoint,
+            push_to_hub = push_to_hub,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_always_push = hub_always_push,
+            hub_revision = hub_revision,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            full_determinism = full_determinism,
+            seed = seed,
+            data_seed = data_seed,
+            use_cpu = use_cpu,
+            accelerator_config = accelerator_config,
+            parallelism_config = parallelism_config,
+            dataloader_drop_last = dataloader_drop_last,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            train_sampling_strategy = train_sampling_strategy,
+            length_column_name = length_column_name,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            ddp_backend = ddp_backend,
+            ddp_timeout = ddp_timeout,
+            fsdp = fsdp,
+            fsdp_config = fsdp_config,
+            deepspeed = deepspeed,
+            debug = debug,
+            skip_memory_metrics = skip_memory_metrics,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            resume_from_checkpoint = resume_from_checkpoint,
+            warmup_ratio = warmup_ratio,
+            logging_dir = logging_dir,
+            local_rank = local_rank,
+            max_length = max_length,
+            max_prompt_length = max_prompt_length,
+            max_completion_length = max_completion_length,
+            disable_dropout = disable_dropout,
+            step_separator = step_separator,
+            train_on_last_step_only = train_on_last_step_only,
+            dataset_num_proc = dataset_num_proc,**kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        if unsloth_grpo_mini_batch is not None:
+            if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+            else:
+                raise ValueError(
+                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                    f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
+                )
+        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+        self.max_seq_length = max_seq_length
+
+pass
646
+
647
+ class _UnslothPRMTrainer(BaseTrainer):
648
+ """"""
649
+
650
+ _tag_names = ["trl", "prm"]
651
+ _name = "PRM"
652
+ _paper = {
653
+ "title": "Solving math word problems with process-and outcome-based feedback",
654
+ "id": "2211.14275",
655
+ # docstyle-ignore
656
+ "citation": textwrap.dedent("""\
657
+ @article{uesato2022solving,
658
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
659
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
660
+ year = 2022,
661
+ journal = {arXiv preprint arXiv:2211.14275}
662
+ }"""),
663
+ }
664
+
665
+ def __init__(
666
+ self,
667
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
668
+ args: Optional[PRMConfig] = None,
669
+ data_collator: Optional[DataCollator] = None,
670
+ train_dataset: Optional[Dataset] = None,
671
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
672
+ processing_class: Optional[
673
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
674
+ ] = None,
675
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
676
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
677
+ callbacks: Optional[list[TrainerCallback]] = None,
678
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
679
+ None,
680
+ None,
681
+ ),
682
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
683
+ peft_config: Optional[dict] = None,
684
+ ):
685
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
686
+ warnings.warn(
687
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
688
+ "it and want it to remain, please share your comments here: "
689
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
690
+ "TRL_EXPERIMENTAL_SILENCE=1."
691
+ )
692
+ if False:
693
+ pass
694
+
695
+ # Disable dropout in the model
696
+ if args.disable_dropout:
697
+ disable_dropout_in_model(model)
698
+
699
+ if compute_metrics is None:
700
+ compute_metrics = compute_accuracy
701
+
702
+ if data_collator is None:
703
+ if processing_class is None:
704
+ raise ValueError(
705
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
706
+ )
707
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
708
+
709
+ if "input_ids" not in train_dataset.column_names:
710
+ with PartialState().main_process_first():
711
+ fn_kwargs = {
712
+ "tokenizer": processing_class,
713
+ "step_separator": args.step_separator,
714
+ "max_length": args.max_length,
715
+ "max_prompt_length": args.max_prompt_length,
716
+ "max_completion_length": args.max_completion_length,
717
+ "train_on_last_step_only": args.train_on_last_step_only,
718
+ }
719
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
720
+ train_dataset = train_dataset.map(
721
+ self.tokenize_row,
722
+ fn_kwargs=train_fn_kwargs,
723
+ num_proc=args.dataset_num_proc,
724
+ remove_columns=train_dataset.features,
725
+ desc="Tokenizing train dataset",
726
+ features=features.Features( # needed to avoid map to cast labels to bool
727
+ {
728
+ "labels": features.Sequence(features.Value("int64")),
729
+ "input_ids": features.Sequence(features.Value("int64")),
730
+ }
731
+ ),
732
+ )
733
+
734
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
735
+ if eval_dataset is not None:
736
+ eval_dataset = eval_dataset.map(
737
+ self.tokenize_row,
738
+ fn_kwargs=eval_fn_kwargs,
739
+ num_proc=args.dataset_num_proc,
740
+ remove_columns=eval_dataset.features,
741
+ desc="Tokenizing eval dataset",
742
+ features=features.Features( # needed to avoid map to cast labels to bool
743
+ {
744
+ "labels": features.Sequence(features.Value("int64")),
745
+ "input_ids": features.Sequence(features.Value("int64")),
746
+ }
747
+ ),
748
+ )
749
+
750
+ super().__init__(
751
+ model=model,
752
+ args=args,
753
+ data_collator=data_collator,
754
+ train_dataset=train_dataset,
755
+ eval_dataset=eval_dataset,
756
+ processing_class=processing_class,
757
+ model_init=model_init,
758
+ compute_metrics=compute_metrics,
759
+ callbacks=callbacks,
760
+ optimizers=optimizers,
761
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
762
+ )
763
+
764
+ # Add tags for models that have been loaded with the correct transformers version
765
+ if hasattr(self.model, "add_model_tags"):
766
+ self.model.add_model_tags(self._tag_names)
767
+
768
+ @staticmethod
769
+ def tokenize_row(
770
+ features,
771
+ tokenizer,
772
+ step_separator,
773
+ max_length,
774
+ max_prompt_length,
775
+ max_completion_length,
776
+ train_on_last_step_only,
777
+ is_eval,
778
+ ):
779
+ r"""
780
+ Tokenize a row of the dataset.
781
+
782
+ Args:
783
+ features (`dict[str, str]`):
784
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
785
+ tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
786
+ Tokenizer used to process the data.
787
+ step_separator (`str`):
788
+ Separator between steps in the completion.
789
+ max_length (`int` or `None`):
790
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
791
+ max_prompt_length (`int` or `None`):
792
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
793
+ max_completion_length (`int` or `None`):
794
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
795
+ train_on_last_step_only (`bool`):
796
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
797
+ token of the completion.
798
+ is_eval (`bool`):
799
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
800
+ `train_on_last_step_only` is set to `True`.
801
+
802
+ Returns:
803
+ `dict[str, list[int]]`:
804
+ Tokenized sequences with the keys `"input_ids"` and `"labels"`.
805
+
806
+ Example:
807
+ ```python
808
+ >>> from transformers import AutoTokenizer
809
+
810
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
811
+ >>> features = {
812
+ ... "prompt": "Which number is larger, 9.8 or 9.11?",
813
+ ... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
814
+ ... "labels": [True, False],
815
+ ... }
816
+ >>> PRMTrainer.tokenize_row(
817
+ ... features, tokenizer, "\n", max_length=None, max_prompt_length=None, max_completion_length=None, train_on_last_step_only=False, is_eval=False
818
+ ... )
819
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
820
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
821
+ ```
822
+ """
823
+ # Tokenize the prompt and completions
824
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
825
+ completions_ids = [
826
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
827
+ ]
828
+ if train_on_last_step_only and not is_eval:
829
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
830
+ else:
831
+ labels = [int(label) for label in features["labels"]]
832
+
833
+ # Get the ID of the separator token and add it to the completions
834
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
835
+ completions_ids = [completion + separator_ids for completion in completions_ids]
836
+
837
+ # Create the label
838
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
839
+
840
+ # Join the completions and labels steps
841
+ completion_ids = list(chain(*completions_ids))
842
+ labels = list(chain(*labels))
843
+
844
+ if tokenizer.bos_token_id is not None:
845
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
846
+
847
+ # Truncate prompt and completion sequences
848
+ if max_prompt_length is not None:
849
+ prompt_ids = prompt_ids[-max_prompt_length:]
850
+ if max_completion_length is not None:
851
+ completion_ids = completion_ids[:max_completion_length]
852
+ labels = labels[:max_completion_length]
853
+
854
+ input_ids = prompt_ids + completion_ids
855
+ labels = [-100] * len(prompt_ids) + labels
856
+
857
+ if max_length is not None:
858
+ input_ids = input_ids[:max_length]
859
+ labels = labels[:max_length]
860
+
861
+ return {"input_ids": input_ids, "labels": labels}
862
+
863
+ # Ensure the model card is saved along with the checkpoint
864
+ def _save_checkpoint(self, model, trial):
865
+ if self.args.hub_model_id is None:
866
+ model_name = Path(self.args.output_dir).name
867
+ else:
868
+ model_name = self.args.hub_model_id.split("/")[-1]
869
+ self.create_model_card(model_name=model_name)
870
+ super()._save_checkpoint(model, trial)
871
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
872
+ """
873
+
874
+ Initialize PRMTrainer.
875
+
876
+ Args:
877
+ model ([`~transformers.PreTrainedModel`]):
878
+ The model to train, preferably an `AutoModelForTokenClassification`.
879
+ args ([`PRMConfig`]):
880
+ The arguments to use for training.
881
+ data_collator ([`~transformers.DataCollator`]):
882
+ The data collator to use for training. If None is specified, the default data collator
883
+ ([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the
884
+ maximum length of the sequences in the batch, given a dataset of paired sequences.
885
+ train_dataset ([`~datasets.Dataset`]):
886
+ The dataset to use for training.
887
+ eval_dataset ([`~datasets.Dataset`]):
888
+ The dataset to use for evaluation.
889
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
890
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
891
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
892
+ reuse the fine-tuned model.
893
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
894
+ The model initializer to use for training. If None is specified, the default model initializer will be
895
+ used.
896
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
897
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
898
+ will be used.
899
+ callbacks (`list[transformers.TrainerCallback]`):
900
+ The callbacks to use for training.
901
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
902
+ The optimizer and scheduler to use for training.
903
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
904
+ The function to use to preprocess the logits before computing the metrics.
905
+ peft_config (`dict`, defaults to `None`):
906
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
907
+ a PEFT model.
908
+
909
+ """
910
+ def __init__(
911
+ self,
912
+ model = None,
913
+ args = None,
914
+ data_collator = None,
915
+ train_dataset = None,
916
+ eval_dataset = None,
917
+ processing_class = None,
918
+ model_init = None,
919
+ compute_metrics = None,
920
+ callbacks = None,
921
+ preprocess_logits_for_metrics = None,
922
+ peft_config = None,
923
+ **kwargs
924
+ ):
925
+ if args is None: args = UnslothPRMConfig()
926
+ use_bf16 = getattr(args, 'bf16', False)
927
+ if type(use_bf16) is not bool: use_bf16 = False
928
+ use_fp16 = getattr(args, 'fp16', False)
929
+ if type(use_fp16) is not bool: use_fp16 = False
930
+ force_float32 = False
931
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
932
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
933
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
934
+ force_float32 = True
935
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
936
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
937
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
938
+ from unsloth_zoo.utils import _get_dtype
939
+ dtype = _get_dtype(dtype)
940
+ float16 = dtype == torch.float16
941
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
942
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
943
+ if force_float32:
944
+ # Forced float32 training
945
+ args.fp16 = False
946
+ args.bf16 = False
947
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
948
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
949
+ # args.mixed_precision is a new argument which needs to be set now
950
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
951
+ # Mixed precision training
952
+ args.fp16 = float16
953
+ args.bf16 = not float16
954
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
955
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
956
+ # args.mixed_precision is a new argument which needs to be set now
957
+ elif mixed_precision_dtype == 'bfloat16':
958
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
959
+ args.fp16 = False
960
+ args.bf16 = False
961
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
962
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
963
+ # args.mixed_precision is a new argument which needs to be set now
964
+
965
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
966
+ args.eval_strategy = 'steps'
967
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
968
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
969
+ if ga_steps is not None and ga_steps > 1:
970
+ from transformers import __version__ as transformers_version
971
+ if Version(transformers_version) <= Version('4.45.2'):
972
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
973
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
974
+ if getattr(args, 'eval_strategy', 'no') != 'no':
975
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
976
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
977
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
978
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
979
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
980
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
981
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
982
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
983
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
984
+ if force_float32:
985
+ args.bf16_full_eval = False
986
+ args.fp16_full_eval = False
987
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
988
+ args.bf16_full_eval = True
989
+ args.fp16_full_eval = False
990
+ elif not bf16_full_eval and not fp16_full_eval:
991
+ args.bf16_full_eval = args.bf16
992
+ args.fp16_full_eval = args.fp16
993
+ _output_logits = False
994
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
995
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
996
+ if _output_logits:
997
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
998
+ if model is not None:
999
+ _warnings_issued = getattr(model, 'warnings_issued', None)
1000
+ if _warnings_issued is None:
1001
+ model.warnings_issued = {}
1002
+ elif not isinstance(_warnings_issued, dict):
1003
+ try:
1004
+ model.warnings_issued = dict(_warnings_issued)
1005
+ except Exception:
1006
+ model.warnings_issued = {}
1007
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1008
+ pass
1009
+ else:
1010
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1011
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1012
+ if args_max_seq_length is None and model_max_seq_length is not None:
1013
+ max_seq_length = model.max_seq_length
1014
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1015
+ elif args_max_seq_length is not None and model_max_seq_length is not None:
1016
+ if args_max_seq_length > model_max_seq_length:
1017
+ print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
1018
+ 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
1019
+ args.max_seq_length = model_max_seq_length
1020
+ if model is not None and hasattr(model, 'for_training'):
1021
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1022
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1023
+ if 'processing_class' in locals():
1024
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1025
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1026
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1027
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1028
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1029
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1030
+ data_collator = TransformersDataCollatorForLanguageModeling(
1031
+ __tokenizer,
1032
+ mlm = False,
1033
+ mlm_probability = 0.0,
1034
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1035
+ )
1036
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1037
+ data_collator = DataCollatorForSeq2Seq(
1038
+ __tokenizer,
1039
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1040
+ )
1041
+ else:
1042
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1043
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1044
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1045
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1046
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1047
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1048
+ data_collator = DataCollatorForSeq2Seq(
1049
+ __tokenizer.tokenizer,
1050
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1051
+ )
1052
+ else:
1053
+ data_collator = TransformersDataCollatorForLanguageModeling(
1054
+ __tokenizer.tokenizer,
1055
+ mlm = False,
1056
+ mlm_probability = 0.0,
1057
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1058
+ )
1059
+ other_metrics = []
1060
+
1061
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1062
+ PatchRLStatistics('prm_trainer', other_metrics)
1063
+
1064
+ # [TODO] Fix up DataParallel multiplying batch sizes
1065
+ # [TODO] DDP works, but DP seems to not work? [TODO]
1066
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
1067
+ if getattr(args, "_n_gpu", 1) != 1:
1068
+ args._n_gpu = 1
1069
+ if "model" in locals() and hasattr(model, "for_training"):
1070
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1071
+ super().__init__(
1072
+ model = model,
1073
+ args = args,
1074
+ data_collator = data_collator,
1075
+ train_dataset = train_dataset,
1076
+ eval_dataset = eval_dataset,
1077
+ processing_class = processing_class,
1078
+ model_init = model_init,
1079
+ compute_metrics = compute_metrics,
1080
+ callbacks = callbacks,
1081
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1082
+ peft_config = peft_config,**kwargs)
1083
+ if "model" in locals() and hasattr(model, "for_inference"):
1084
+ model.for_inference()
1085
+ if hasattr(self, 'neftune_hook_handle'):
1086
+ self.neftune_hook_handle.remove()
1087
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1088
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1089
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1090
+ pass
1091
+ if hasattr(self, 'accelerator'):
1092
+ scaler = self.accelerator.scaler
1093
+ current_model = model
1094
+ while hasattr(current_model, 'model'):
1095
+ current_model.accelerator_scaler = scaler
1096
+ current_model = current_model.model
1097
+ current_model.accelerator_scaler = scaler
1098
+ pass
1099
+ if hasattr(self, 'train'):
1100
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1101
+ pass
1102
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1103
+ _vllm_tok = self.llm.get_tokenizer()
1104
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1105
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1106
+ _vllm_tok.chat_template = _pc.chat_template
1107
+ pass
1108
+
1109
+ pass
unsloth_compiled_cache/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,1327 @@
1
+ """
2
+ 2026.3.2
3
+ 2026.3.4
4
+ 5.3.0
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from unsloth_zoo.temporary_patches.common import torch_compile
30
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
31
+ from trl.trainer.reward_trainer import (Any, AutoModelForSequenceClassification, AutoTokenizer, BaseTrainer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PartialState, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardTrainer, TrainerCallback, Union, clone_chat_template, contextlib, dataclass, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pad, re, remove_none_values, suppress_from_pretrained_warning, torch, transformers, Any, AutoModelForSequenceClassification, AutoTokenizer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, TrainerCallback, Union, clone_chat_template, contextlib, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, logger, os, pad, re, suppress_from_pretrained_warning, torch, transformers, PreTrainedModel, logger, os, re, torch)
32
+
33
+
34
+ import os
35
+ import math
36
+ import logging
37
+ from typing import *
38
+ from dataclasses import dataclass, field
39
+ from packaging.version import Version
40
+ import torch
41
+ import numpy as np
42
+ from contextlib import nullcontext
43
+ from torch.nn import functional as F
44
+ import inspect
45
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
46
+ from transformers.training_args import ParallelMode
47
+ from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
48
+
49
+ # Wrap trainer with padding to right and enable training mode
50
+ # Also patches W&B since multiple runs must use wandb.finish()
51
+ import functools
52
+ from types import MethodType
53
+ try:
54
+ from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
55
+ except:
56
+ def reset_unsloth_gradient_checkpointing_buffers(): pass
57
+ def prepare_for_training_mode(f):
58
+ @functools.wraps(f)
59
+ def wrapper(self, *args, **kwargs):
60
+ # Enable training mode
61
+ _was_training = None
62
+ # Get gradient checkpointing setting from training arguments
63
+ use_gc = getattr(self.args, 'gradient_checkpointing', True)
64
+ if hasattr(self, 'model') and hasattr(self.model, "training"):
65
+ _was_training = self.model.training
66
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
67
+ self.model.for_training(use_gradient_checkpointing=use_gc)
68
+ output = f(self, *args, **kwargs)
69
+ # Restore previous mode when possible
70
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
71
+ if _was_training is False:
72
+ self.model.for_inference()
73
+ elif _was_training is True and hasattr(self.model, "for_training"):
74
+ self.model.for_training(use_gradient_checkpointing=use_gc)
75
+ # Reset gradient checkpointing buffers to free memory while staying ready for next run
76
+ try:
77
+ reset_unsloth_gradient_checkpointing_buffers()
78
+ except:
79
+ pass
80
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
81
+ try:
82
+ import wandb
83
+ wandb.finish()
84
+ except:
85
+ pass
86
+ return output
87
+ return wrapper
88
+ pass
89
+
90
+ torch_compile_options = {
91
+ "epilogue_fusion" : True,
92
+ "max_autotune" : False,
93
+ "shape_padding" : True,
94
+ "trace.enabled" : False,
95
+ "triton.cudagraphs" : False,
96
+ }
97
+
98
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
99
+ def chunked_hidden_states_selective_log_softmax(
100
+ hidden_states: torch.Tensor,
101
+ lm_head: torch.Tensor,
102
+ index: torch.Tensor,
103
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
160
+
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given a prompt tensor, returns the number of left-pad tokens in each sequence, so [pad, pad, pad, cat] = 3 tokens.
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+
+    padding_mask = (prompt_section == pad_token_id)
+
+    pad_token_counts = padding_mask.sum(dim=1)
+
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given a sequence [p, p, p, c, c, c, pad, pad, pad],
+
+    where p are extra prompt tokens we got from slicing the torch tensor, c are completion tokens,
+    and pad are pad tokens, this function builds a completion mask that zeroes out the pad
+    and p tokens, so in this example [0, 0, 0, 1, 1, 1, 0, 0, 0].
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    # Must use stable=True since the binary mask is unordered
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
216
+
+def align_logprobs_with_mask(
+    logprob_tensor: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pad_value: float = 0.0
+) -> torch.Tensor:
+    """
+    Aligns a log probability tensor with a given attention mask.
+    """
+    device = logprob_tensor.device
+    batch_size, logprob_seq_len = logprob_tensor.shape
+    mask_seq_len = attention_mask.shape[1]
+
+    padded_logprobs = torch.full(
+        attention_mask.shape,
+        fill_value=pad_value,
+        dtype=logprob_tensor.dtype,
+        device=device
+    )
+
+    left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+    cols = torch.arange(logprob_seq_len, device=device)
+    dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+    # Create destination row indices
+    # Shape: [batch_size, logprob_seq_len]
+    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+    # Filter out-of-bounds indices and perform the assignment.
+    # Create a mask to identify only the indices that are within the bounds
+    # of the target tensor's sequence length.
+    valid_mask = dest_indices < mask_seq_len
+
+    # Use this mask to select only the valid row indices, column indices,
+    # and the corresponding values from the logprob tensor.
+    # This flattens the selected elements into 1D tensors.
+    valid_rows = row_indices[valid_mask]
+    valid_cols = dest_indices[valid_mask]
+    valid_vals = logprob_tensor[valid_mask]
+
+    # Place the valid values into their correct positions in the padded tensor
+    # using a single, efficient advanced indexing operation.
+    padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+    return padded_logprobs
263
+
+def autotune_batch_and_chunks(
+    total_input_rows,
+    seq_len,
+    hidden_size,
+    vocab_size,
+    dtype_bytes=16,
+    multiplier=None
+):
+    if multiplier is None:
+        final_m = max(4, seq_len // 4096)
+    else:
+        final_m = multiplier
+
+    if torch.cuda.is_available():
+        free_bytes, _ = torch.cuda.mem_get_info()
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        # For XPU: estimate free memory from total - reserved
+        total_mem = torch.xpu.get_device_properties(0).total_memory
+        reserved_mem = torch.xpu.memory_reserved()
+        free_bytes = total_mem - reserved_mem
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    else:
+        # Fallback: assume 8GB available
+        limit_gb = 8.0
+
+    bytes_to_gb = 1024**3
+
+    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+    logits_gb = base_logits / final_m
+
+    total_mem_gb = hidden_gb + logits_gb
+
+    valid_mask = total_mem_gb <= limit_gb
+    valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+    if valid_indices.shape[0] == 0:
+        # This means your GPU will OOM
+        return 4, final_m
+
+    best_idx = valid_indices[0].item()
+    final_b = int(b_vals[best_idx].item())
+
+    return final_b, final_m
+
+def sanitize_logprob(logprob):
+    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+    Filters NaN logprobs from vLLM outputs."""
+    value = logprob.logprob
+    if math.isnan(value):
+        logging.getLogger(__name__).warning(
+            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+        )
+        return None
+    return value
+
323
+@dataclass
+class UnslothRewardConfig(RewardConfig):
+    """
+
+    Configuration class for the [`RewardTrainer`].
+
+    This class includes only the parameters that are specific to Reward training. For a full list of training
+    arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
+    class may differ from those in [`~transformers.TrainingArguments`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        > Parameters that control the model
+
+        model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+            argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
+            to include the load balancing/auxiliary loss as a part of the final loss, remember to set
+            `output_router_logits=True` in this dictionary.
+        chat_template_path (`str`, *optional*):
+            If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
+            or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
+            ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
+            embedding layer is resized accordingly.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+
+        > Parameters that control the data preprocessing
+
+        dataset_num_proc (`int`, *optional*):
+            Number of processes to use for processing the dataset.
+        eos_token (`str`, *optional*):
+            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
+            `processing_class.eos_token`.
+        pad_token (`str`, *optional*):
+            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+            it falls back to `processing_class.eos_token`.
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
+            exceeds this value. If `None`, no filtering is applied.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+
+        > Parameters that control the training
+
+        center_rewards_coefficient (`float`, *optional*):
+            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
+            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
+        activation_offloading (`bool`, *optional*, defaults to `False`):
+            Whether to offload the activations to the CPU.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks: Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    unsloth_logit_chunk_multiplier: Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Multiplier for chunked logit computations.'},
+    )
+    unsloth_grpo_mini_batch: Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+    )
+    max_seq_length: Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Maximum sequence length to truncate to.'},
+    )
398
+    def __init__(
+        self,
+        output_dir = None,
+        per_device_train_batch_size = 4,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        learning_rate = 5e-05,
+        lr_scheduler_type = 'linear',
+        lr_scheduler_kwargs = None,
+        warmup_steps = 0.1,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        optim_target_modules = None,
+        gradient_accumulation_steps = 2,
+        average_tokens_across_devices = True,
+        max_grad_norm = 1.0,
+        label_smoothing_factor = 0.0,
+        bf16 = False,
+        fp16 = False,
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = None,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        use_liger_kernel = False,
+        liger_kernel_config = None,
+        use_cache = False,
+        neftune_noise_alpha = None,
+        torch_empty_cache_steps = 250,
+        auto_find_batch_size = False,
+        logging_strategy = 'steps',
+        logging_steps = 1,
+        logging_first_step = False,
+        log_on_each_node = True,
+        logging_nan_inf_filter = False,
+        include_num_input_tokens_seen = False,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        disable_tqdm = None,
+        report_to = 'none',
+        run_name = None,
+        project = 'huggingface',
+        trackio_space_id = 'trackio',
+        eval_strategy = 'no',
+        eval_steps = None,
+        eval_delay = 0,
+        per_device_eval_batch_size = 4,
+        prediction_loss_only = False,
+        eval_on_start = False,
+        eval_do_concat_batches = True,
+        eval_use_gather_object = False,
+        eval_accumulation_steps = 2,
+        batch_eval_metrics = False,
+        save_only_model = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_on_each_node = False,
+        save_total_limit = None,
+        enable_jit_checkpoint = False,
+        push_to_hub = False,
+        hub_token = None,
+        hub_private_repo = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_always_push = False,
+        hub_revision = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        restore_callback_states_from_checkpoint = False,
+        full_determinism = False,
+        seed = 3407,
+        data_seed = 3407,
+        use_cpu = False,
+        accelerator_config = None,
+        parallelism_config = None,
+        dataloader_drop_last = False,
+        dataloader_num_workers = 0,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        dataloader_prefetch_factor = None,
+        remove_unused_columns = True,
+        label_names = None,
+        train_sampling_strategy = 'random',
+        length_column_name = 'length',
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        ddp_backend = None,
+        ddp_timeout = 1800,
+        fsdp = None,
+        fsdp_config = None,
+        deepspeed = None,
+        debug = '',
+        skip_memory_metrics = True,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        resume_from_checkpoint = None,
+        warmup_ratio = None,
+        logging_dir = None,
+        local_rank = -1,
+        model_init_kwargs = None,
+        chat_template_path = None,
+        disable_dropout = True,
+        dataset_num_proc = None,
+        eos_token = None,
+        pad_token = None,
+        max_length = 1024,
+        pad_to_multiple_of = None,
+        center_rewards_coefficient = None,
+        activation_offloading = False,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        unsloth_logit_chunk_multiplier = None,
+        unsloth_grpo_mini_batch = None,
+        max_seq_length = None,
+        **kwargs,
+    ):
525
+        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+        if num_train_epochs is None:
+            num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        import multiprocessing as _mp
+        if _mp.get_start_method() != 'fork':
+            dataset_num_proc = None
+        elif dataset_num_proc is None:
+            import psutil
+            dataset_num_proc = min(max((psutil.cpu_count() or 1) + 4, 2), 64)
+            memory_gb_left = psutil.virtual_memory().available / (1024**3)
+            if memory_gb_left <= 2: dataset_num_proc = 1
+            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+        if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
+            from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
+            if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
+                from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
+                pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
+
548
+        super().__init__(
+            output_dir = output_dir,
+            per_device_train_batch_size = per_device_train_batch_size,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            learning_rate = learning_rate,
+            lr_scheduler_type = lr_scheduler_type,
+            lr_scheduler_kwargs = lr_scheduler_kwargs,
+            warmup_steps = warmup_steps,
+            optim = optim,
+            optim_args = optim_args,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            optim_target_modules = optim_target_modules,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            average_tokens_across_devices = average_tokens_across_devices,
+            max_grad_norm = max_grad_norm,
+            label_smoothing_factor = label_smoothing_factor,
+            bf16 = bf16,
+            fp16 = fp16,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            use_liger_kernel = use_liger_kernel,
+            liger_kernel_config = liger_kernel_config,
+            use_cache = use_cache,
+            neftune_noise_alpha = neftune_noise_alpha,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            auto_find_batch_size = auto_find_batch_size,
+            logging_strategy = logging_strategy,
+            logging_steps = logging_steps,
+            logging_first_step = logging_first_step,
+            log_on_each_node = log_on_each_node,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            disable_tqdm = disable_tqdm,
+            report_to = report_to,
+            run_name = run_name,
+            project = project,
+            trackio_space_id = trackio_space_id,
+            eval_strategy = eval_strategy,
+            eval_steps = eval_steps,
+            eval_delay = eval_delay,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            prediction_loss_only = prediction_loss_only,
+            eval_on_start = eval_on_start,
+            eval_do_concat_batches = eval_do_concat_batches,
+            eval_use_gather_object = eval_use_gather_object,
+            eval_accumulation_steps = eval_accumulation_steps,
+            batch_eval_metrics = batch_eval_metrics,
+            save_only_model = save_only_model,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_on_each_node = save_on_each_node,
+            save_total_limit = save_total_limit,
+            enable_jit_checkpoint = enable_jit_checkpoint,
+            push_to_hub = push_to_hub,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_always_push = hub_always_push,
+            hub_revision = hub_revision,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            full_determinism = full_determinism,
+            seed = seed,
+            data_seed = data_seed,
+            use_cpu = use_cpu,
+            accelerator_config = accelerator_config,
+            parallelism_config = parallelism_config,
+            dataloader_drop_last = dataloader_drop_last,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            train_sampling_strategy = train_sampling_strategy,
+            length_column_name = length_column_name,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            ddp_backend = ddp_backend,
+            ddp_timeout = ddp_timeout,
+            fsdp = fsdp,
+            fsdp_config = fsdp_config,
+            deepspeed = deepspeed,
+            debug = debug,
+            skip_memory_metrics = skip_memory_metrics,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            resume_from_checkpoint = resume_from_checkpoint,
+            warmup_ratio = warmup_ratio,
+            logging_dir = logging_dir,
+            local_rank = local_rank,
+            model_init_kwargs = model_init_kwargs,
+            chat_template_path = chat_template_path,
+            disable_dropout = disable_dropout,
+            dataset_num_proc = dataset_num_proc,
+            eos_token = eos_token,
+            pad_token = pad_token,
+            max_length = max_length,
+            pad_to_multiple_of = pad_to_multiple_of,
+            center_rewards_coefficient = center_rewards_coefficient,
+            activation_offloading = activation_offloading,
+            **kwargs)
667
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        if unsloth_grpo_mini_batch is not None:
+            if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+            else:
+                raise ValueError(
+                    "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                    "which is self.per_device_train_batch_size * gradient_accumulation_steps."
+                )
+        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+        self.max_seq_length = max_seq_length
+
+    pass
681
+
+class _UnslothRewardTrainer(BaseTrainer):
+    """"""
+
+    _tag_names = ["trl", "reward-trainer"]
+    _name = "Reward"
+    _template_file = "rm_model_card.md"
+
+    def __init__(
+        self,
+        model: Union[str, PreTrainedModel],
+        args: Optional[RewardConfig] = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        peft_config: Optional["PeftConfig"] = None,
+    ):
704
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = RewardConfig(f"{model_name}-Reward")
+
+        # Model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            dtype = model_init_kwargs.get("dtype")
+            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
+                pass  # dtype is already a torch.dtype or "auto" or None
+            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
+                model_init_kwargs["dtype"] = getattr(torch, dtype)
+            else:
+                raise ValueError(
+                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
+                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
+                )
+            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
+                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                logger.warning(
+                    "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
+                    "The `model_init_kwargs` will be ignored."
+                )
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(model_id)
+
+        # Handle pad token for processors or tokenizers
+        if args.eos_token is not None:
+            eos_token = args.eos_token
+            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
+            if eos_token_id is None:
+                raise ValueError(
+                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                    "in the vocabulary before using it as an EOS token."
+                )
+            processing_class.eos_token_id = eos_token_id
+
+        if args.chat_template_path is not None:
+            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                    processing_class.chat_template = chat_template_file.read()
+                added_tokens = []
+            else:
+                model, processing_class, added_tokens = clone_chat_template(
+                    model, processing_class, args.chat_template_path
+                )
+        else:
+            added_tokens = []
+
762
+        # PEFT configuration and model wrapping
+        if False:
+            if added_tokens:
+                # Ensure that the added tokens are trainable
+                if peft_config.trainable_token_indices is None:
+                    peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
+                elif "embed_tokens" not in peft_config.trainable_token_indices:
+                    peft_config.trainable_token_indices["embed_tokens"] = added_tokens
+                else:
+                    peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
+
+                # Ensure that the lm_head is trainable
+                if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
+                    logger.warning(
+                        "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
+                        "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
+                        "tokens, leading to degraded generation quality. To fix this, add "
+                        "`modules_to_save=['lm_head']` to your PEFT configuration."
+                    )
+
+                    if peft_config.modules_to_save is None:
+                        peft_config.modules_to_save = ["lm_head"]
+                    else:
+                        peft_config.modules_to_save.append("lm_head")
+
+        if False:
+            pass
+
790
+        # Disable dropout in the model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
+        # Pad token (needed for SequenceClassification models).
+        # If not provided, use the one from the processing class, or the eos token if the processing class does not
+        # have a pad token.
+        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
+        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+        if pad_token_id is None:
+            raise ValueError(
+                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                "in the vocabulary before using it as a padding token."
+            )
+        model.config.pad_token_id = pad_token_id
+        processing_class.pad_token_id = pad_token_id
+
+        # Data collator
+        if data_collator is None:
+            data_collator = DataCollatorForPreference(
+                pad_token_id=pad_token_id,
+                pad_to_multiple_of=args.pad_to_multiple_of,
+            )
+
+        # Dataset
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
+
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
+
+        # Initialize the Trainer. The parent class will handle:
+        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
+        # - FSDP setup
+        # - Distributed training setup
+        # - Optimizer and scheduler creation
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
849
+
+        # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
+        self.can_return_loss = True
+        self.label_names = []
+
+        # Initialize activation offloading context
+        if self.args.activation_offloading:
+            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
+        else:
+            self.maybe_activation_offload_context = contextlib.nullcontext()
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
865
+
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: PreTrainedTokenizerBase,
+        args: RewardConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
+        # sampled data.
+        if isinstance(dataset, Dataset):  # IterableDataset does not support `with_transform`
+            dataset = dataset.with_transform(remove_none_values)
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            if not is_processed:
+                # Add EOS token to the end of the sequences if needed
+                first_example = next(iter(dataset))
+                if not is_conversational(first_example):
+                    if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                        map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
+
+                    def add_eos(example, eos_token):
+                        if not example["chosen"].endswith(eos_token):
+                            example["chosen"] = example["chosen"] + eos_token
+                        if "rejected" in example and not example["rejected"].endswith(eos_token):
+                            example["rejected"] = example["rejected"] + eos_token
+                        return example
+
+                    dataset = dataset.map(
+                        add_eos,
+                        fn_kwargs={"eos_token": processing_class.eos_token},
+                        **map_kwargs,
+                    )
+
+                # Tokenize the dataset
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize_fn(example, processing_class):
+                    if "prompt" in example:  # explicit prompt case
+                        example["chosen"] = example["prompt"] + example["chosen"]
+                        example["rejected"] = example["prompt"] + example["rejected"]
+
+                    if is_conversational(example):
+                        chosen_input_ids = processing_class.apply_chat_template(
+                            example["chosen"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        rejected_input_ids = processing_class.apply_chat_template(
+                            example["rejected"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
+                    else:
+                        output = {
+                            "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
+                            "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
+                        }
+                    return output
+
+                dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
+
+            # Filter samples that are longer than `max_length`
+            if args.max_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
+                dataset = dataset.filter(
+                    lambda example: len(example["chosen_input_ids"]) <= args.max_length
+                    and len(example["rejected_input_ids"]) <= args.max_length,
+                    **map_kwargs,
+                )
+
+        return dataset
949
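The `add_eos`/`tokenize_fn` steps above can be sketched in plain Python. This is a minimal illustration, not the TRL implementation: the `EOS` string and `toy_tokenize` helper below are hypothetical stand-ins for `processing_class.eos_token` and a real tokenizer.

```python
EOS = "</s>"  # assumed EOS token for illustration

def add_eos(example, eos_token):
    # Mirror of the diff's add_eos: append EOS only when it is missing
    if not example["chosen"].endswith(eos_token):
        example["chosen"] = example["chosen"] + eos_token
    if "rejected" in example and not example["rejected"].endswith(eos_token):
        example["rejected"] = example["rejected"] + eos_token
    return example

def toy_tokenize(text):
    # Hypothetical stand-in for `processing_class(text=...)["input_ids"]`
    return text.replace(EOS, f" {EOS}").split()

example = {"chosen": "Paris is the capital.", "rejected": "London is the capital." + EOS}
example = add_eos(example, EOS)       # "rejected" already ends with EOS, so it is untouched
chosen_ids = toy_tokenize(example["chosen"])
```

With a real preference dataset, this pair of functions would be handed to `dataset.map` exactly as in the diff; the only difference here is the toy tokenizer.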
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+        # and "attention_mask").
+        if self._signature_columns is None:
+            self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
+
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        Compute training loss and additionally compute token accuracies
+        """
+        mode = "train" if self.model.training else "eval"
+
+        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
+        inputs["use_cache"] = False
+        outputs = model(**inputs)
+
+        # Split the rewards into chosen and rejected
+        rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
+
+        # Calculate loss, optionally modulate with margin
+        if "margin" in inputs:
+            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
+        else:
+            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
+
+        if self.args.center_rewards_coefficient is not None:
+            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
+
+        if mode == "train":
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+            self._total_train_tokens += num_tokens_in_batch
+            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+
+        # Compute min, mean, max, accuracy and margin
+        with torch.no_grad():
+            all_rewards = self.accelerator.gather(outputs.logits)
+            self._metrics[mode]["min_reward"].append(all_rewards.min().item())
+            self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
+            self._metrics[mode]["max_reward"].append(all_rewards.max().item())
+
+            mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
+            mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
+            self._metrics[mode]["accuracy"].append(mean_accuracy)
+
+            mean_margin = (rewards_chosen - rewards_rejected).mean()
+            mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
+            self._metrics[mode]["margin"].append(mean_margin.item())
+
+        return (loss, outputs) if return_outputs else loss
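The pairwise loss in `compute_loss` above is the Bradley-Terry objective `-logsigmoid(r_chosen - r_rejected - margin)`, averaged over the batch, optionally plus a reward-centering penalty on `(r_chosen + r_rejected)**2`. A minimal pure-Python sketch of just the math (scalar rewards per pair; `pairwise_reward_loss` is a hypothetical helper for illustration, not a TRL API):

```python
import math

def logsigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def pairwise_reward_loss(chosen, rejected, margins=None, center_coef=None):
    # Bradley-Terry pairwise loss, mirroring compute_loss above
    n = len(chosen)
    if margins is None:
        margins = [0.0] * n
    loss = -sum(logsigmoid(c - r - m) for c, r, m in zip(chosen, rejected, margins)) / n
    if center_coef is not None:
        # Optional centering penalty keeps rewards from drifting in magnitude
        loss += center_coef * sum((c + r) ** 2 for c, r in zip(chosen, rejected)) / n
    return loss

loss_good = pairwise_reward_loss([2.0], [-2.0])  # chosen clearly preferred -> small loss
loss_bad  = pairwise_reward_loss([-2.0], [2.0])  # preference inverted -> large loss
```

A margin shifts the decision boundary: the chosen reward must beat the rejected one by at least the margin before the loss becomes small.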
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
+
+    # Ensure the model card is saved along with the checkpoint
+    def _save_checkpoint(self, model, trial):
+        if self.args.hub_model_id is None:
+            model_name = Path(self.args.output_dir).name
+        else:
+            model_name = self.args.hub_model_id.split("/")[-1]
+        self.create_model_card(model_name=model_name)
+        super()._save_checkpoint(model, trial)
+class UnslothRewardTrainer(_UnslothRewardTrainer):
+    """
+
+    Trainer for Outcome-supervised Reward Models (ORM).
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+
+    Example:
+
+    ```python
+    from trl import RewardTrainer
+    from datasets import load_dataset
+
+    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+    trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
+    trainer.train()
+    ```
+
+    Args:
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
+              `args.model_init_kwargs`.
+            - A sequence classification [`~transformers.PreTrainedModel`] object.
+        args ([`RewardConfig`], *optional*):
+            Configuration for this trainer. If `None`, a default configuration is used.
+        data_collator ([`~transformers.DataCollator`], *optional*):
+            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
+            Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
+            explicit prompt). The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+
+            The trainer also supports processed (tokenized) datasets, as long as they contain `chosen_input_ids` and
+            `rejected_input_ids` fields.
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
+            [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
+            set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
+            default.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a
+            [`~transformers.EvalPrediction`] and return a dictionary mapping strings to metric values. When passing
+            [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
+            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
+            function needs to calculate and return the global summary statistics rather than accumulating the
+            batch-level statistics.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
+            `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
+            initializing the Trainer.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+        peft_config ([`~peft.PeftConfig`], *optional*):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
+            model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
+            to ensure that the reward head is properly trained.
+
+    """
+    def __init__(
+        self,
+        model,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        compute_metrics = None,
+        callbacks = None,
+        optimizer_cls_and_kwargs = None,
+        preprocess_logits_for_metrics = None,
+        peft_config = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothRewardConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        if type(use_bf16) is not bool: use_bf16 = False
+        use_fp16 = getattr(args, 'fp16', False)
+        if type(use_fp16) is not bool: use_fp16 = False
+        force_float32 = False
+        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            # Forced float32 training
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            # Mixed precision training
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif mixed_precision_dtype == 'bfloat16':
+            # Both False since bfloat16 full finetuning doesn't do any autocasting.
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        if type(fp16_full_eval) is not bool: fp16_full_eval = False
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if type(bf16_full_eval) is not bool: bf16_full_eval = False
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if model is not None:
+            _warnings_issued = getattr(model, 'warnings_issued', None)
+            if _warnings_issued is None:
+                model.warnings_issued = {}
+            elif not isinstance(_warnings_issued, dict):
+                try:
+                    model.warnings_issued = dict(_warnings_issued)
+                except Exception:
+                    model.warnings_issued = {}
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+            elif args_max_seq_length is not None and model_max_seq_length is not None:
+                if args_max_seq_length > model_max_seq_length:
+                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+                    args.max_seq_length = model_max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = TransformersDataCollatorForLanguageModeling(
+                    __tokenizer,
+                    mlm = False,
+                    mlm_probability = 0.0,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(
+                    __tokenizer,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(
+                        __tokenizer.tokenizer,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+                else:
+                    data_collator = TransformersDataCollatorForLanguageModeling(
+                        __tokenizer.tokenizer,
+                        mlm = False,
+                        mlm_probability = 0.0,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('reward_trainer', other_metrics)
+
+        # [TODO] Fix up DataParallel multiplying batch sizes
+        # [TODO] DDP works, but DP seems to not work? [TODO]
+        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+            if getattr(args, "_n_gpu", 1) != 1:
+                args._n_gpu = 1
+        if "model" in locals() and hasattr(model, "for_training"):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        super().__init__(
+            model = model,
+            args = args,
+            data_collator = data_collator,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            compute_metrics = compute_metrics,
+            callbacks = callbacks,
+            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            peft_config = peft_config,**kwargs)
+        if "model" in locals() and hasattr(model, "for_inference"):
+            model.for_inference()
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+        if hasattr(self, 'accelerator'):
+            scaler = self.accelerator.scaler
+            current_model = model
+            while hasattr(current_model, 'model'):
+                current_model.accelerator_scaler = scaler
+                current_model = current_model.model
+            current_model.accelerator_scaler = scaler
+        pass
+        if hasattr(self, 'train'):
+            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
+        pass
+        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
+            _vllm_tok = self.llm.get_tokenizer()
+            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
+            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
+                _vllm_tok.chat_template = _pc.chat_template
+        pass
+
+    pass
+
+
+if hasattr(logger, "addFilter"):
+    import logging
+    class HideLoggingMessage(logging.Filter):
+        def __init__(self, text): self.text = text
+        def filter(self, x): return not (self.text in x.getMessage())
+    pass
+    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
+
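The mixed-precision branching in the `__init__` above reduces to one rule: the autocast mode must match the dtype the weights were loaded in (fp16 weights get fp16 autocast, bf16 weights get bf16), unless float32 is forced. A simplified sketch of that rule only; `pick_mixed_precision` is a hypothetical helper with string dtypes for brevity, not an Unsloth API:

```python
def pick_mixed_precision(model_dtype, force_float32=False):
    # model_dtype: "float16" or "bfloat16" (assumed simplified input)
    # Returns the fp16/bf16 flags plus the ACCELERATE_MIXED_PRECISION value
    if force_float32:
        # Full float32 training: no autocasting at all
        return {"fp16": False, "bf16": False, "accelerate": "no"}
    if model_dtype == "float16":
        # fp16 weights must use fp16 autocast (bf16 would silently corrupt them)
        return {"fp16": True, "bf16": False, "accelerate": "fp16"}
    # bfloat16 weights use bf16 autocast
    return {"fp16": False, "bf16": True, "accelerate": "bf16"}
```

This is why the real code raises a `TypeError` when, say, `bf16=True` is requested for a float16 model: the two precisions cannot be mixed safely.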
unsloth_compiled_cache/UnslothSFTTrainer.py ADDED
@@ -0,0 +1,1588 @@
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.sft_trainer import (Any, AutoProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, dataclass, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pack_dataset, pad, selective_log_softmax, torch, Any, AutoProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, os, pad, torch, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, PreTrainedModel, logger, os, torch, os)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
+try:
+    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+except:
+    def reset_unsloth_gradient_checkpointing_buffers(): pass
+def prepare_for_training_mode(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        # Enable training mode
+        _was_training = None
+        # Get gradient checkpointing setting from training arguments
+        use_gc = getattr(self.args, 'gradient_checkpointing', True)
+        if hasattr(self, 'model') and hasattr(self.model, "training"):
+            _was_training = self.model.training
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+            self.model.for_training(use_gradient_checkpointing=use_gc)
+        output = f(self, *args, **kwargs)
+        # Restore previous mode when possible
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+            if _was_training is False:
+                self.model.for_inference()
+            elif _was_training is True and hasattr(self.model, "for_training"):
+                self.model.for_training(use_gradient_checkpointing=use_gc)
+        # Reset gradient checkpointing buffers to free memory while staying ready for next run
+        try:
+            reset_unsloth_gradient_checkpointing_buffers()
+        except:
+            pass
+        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+        try:
+            import wandb
+            wandb.finish()
+        except:
+            pass
+        return output
+    return wrapper
+pass
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_hidden_states_selective_log_softmax(
+    hidden_states: torch.Tensor,
+    lm_head: torch.Tensor,
+    index: torch.Tensor,
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index  = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
+
+ def calculate_pad_tokens_in_prompt(
+     input_ids: torch.Tensor,
+     logits_to_keep: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given a prompt tensor, returns the number of left-padding tokens in each
+     sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
+     """
+     if logits_to_keep >= input_ids.shape[1]:
+         raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+     prompt_section = input_ids[:, :-logits_to_keep]
+
+     padding_mask = (prompt_section == pad_token_id)
+
+     pad_token_counts = padding_mask.sum(dim=1)
+
+     return pad_token_counts
+
+ def create_completion_attention_mask(
+     completion_input_ids: torch.Tensor,
+     left_pad_tokens_per_prompt: torch.Tensor,
+     max_left_pad: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra prompt
+     tokens obtained from slicing the tensor, c are completion tokens, and pad are
+     padding tokens, build a completion mask that zeroes out the p and pad tokens:
+     in this example, [0, 0, 0, 1, 1, 1, 0, 0, 0].
+     """
+     batch_size, completion_len = completion_input_ids.shape
+     device = completion_input_ids.device
+
+     num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+     indices = torch.arange(completion_len, device=device).unsqueeze(0)
+     shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+     non_padding_mask = (completion_input_ids != pad_token_id)
+
+     final_mask = shift_mask & non_padding_mask
+
+     return final_mask
+
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+     """
+     Moves all padding tokens in each sequence of a batch to the right.
+     """
+     mask = (tensor != pad_id)
+     # Must use stable=True since the binary mask is unordered
+     sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+     packed_tensor = torch.gather(tensor, 1, sorted_indices)
+     return packed_tensor
+
+ def align_logprobs_with_mask(
+     logprob_tensor: torch.Tensor,
+     attention_mask: torch.Tensor,
+     pad_value: float = 0.0
+ ) -> torch.Tensor:
+     """
+     Aligns a log probability tensor with a given attention mask.
+     """
+     device = logprob_tensor.device
+     batch_size, logprob_seq_len = logprob_tensor.shape
+     mask_seq_len = attention_mask.shape[1]
+
+     padded_logprobs = torch.full(
+         attention_mask.shape,
+         fill_value=pad_value,
+         dtype=logprob_tensor.dtype,
+         device=device
+     )
+
+     left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+     cols = torch.arange(logprob_seq_len, device=device)
+     dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+     # Create destination row indices
+     # Shape: [batch_size, logprob_seq_len]
+     row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+     # Filter out-of-bounds indices: keep only destinations that fall within
+     # the target tensor's sequence length.
+     valid_mask = dest_indices < mask_seq_len
+
+     # Select only the valid row indices, column indices, and the corresponding
+     # values from the logprob tensor (flattened into 1D tensors).
+     valid_rows = row_indices[valid_mask]
+     valid_cols = dest_indices[valid_mask]
+     valid_vals = logprob_tensor[valid_mask]
+
+     # Place the valid values into their correct positions in the padded tensor
+     # with a single advanced-indexing operation.
+     padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+     return padded_logprobs
+
+ def autotune_batch_and_chunks(
+     total_input_rows,
+     seq_len,
+     hidden_size,
+     vocab_size,
+     dtype_bytes=16,
+     multiplier=None
+ ):
+     if multiplier is None:
+         final_m = max(4, seq_len // 4096)
+     else:
+         final_m = multiplier
+
+     if torch.cuda.is_available():
+         free_bytes, _ = torch.cuda.mem_get_info()
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     elif hasattr(torch, "xpu") and torch.xpu.is_available():
+         # For XPU: estimate free memory from total - reserved
+         total_mem = torch.xpu.get_device_properties(0).total_memory
+         reserved_mem = torch.xpu.memory_reserved()
+         free_bytes = total_mem - reserved_mem
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     else:
+         # Fallback: assume 8GB available
+         limit_gb = 8.0
+
+     bytes_to_gb = 1024**3
+
+     b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+     hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+     base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+     logits_gb = base_logits / final_m
+
+     total_mem_gb = hidden_gb + logits_gb
+
+     valid_mask = total_mem_gb <= limit_gb
+     valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+     if valid_indices.shape[0] == 0:
+         # This means your GPU will OOM
+         return 4, final_m
+
+     best_idx = valid_indices[0].item()
+     final_b = int(b_vals[best_idx].item())
+
+     return final_b, final_m
+
+ def sanitize_logprob(logprob):
+     """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+     Filters NaN logprobs from vLLM outputs."""
+     value = logprob.logprob
+     if math.isnan(value):
+         logging.getLogger(__name__).warning(
+             f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+         )
+         return None
+     return value
+
+ @dataclass
+ class UnslothSFTConfig(SFTConfig):
+     """
+
+     Configuration class for the [`SFTTrainer`].
+
+     This class includes only the parameters that are specific to SFT training. For a full list of training arguments,
+     please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+     differ from those in [`~transformers.TrainingArguments`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         > Parameters that control the model
+
+         model_init_kwargs (`dict[str, Any]`, *optional*):
+             Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+             argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to
+             include the load balancing/auxiliary loss as a part of the final loss, remember to set
+             `output_router_logits=True` in this dictionary.
+         chat_template_path (`str`, *optional*):
+             If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
+             or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
+             ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
+             embedding layer is resized accordingly.
+
+         > Parameters that control the data preprocessing
+
+         dataset_text_field (`str`, *optional*, defaults to `"text"`):
+             Name of the column that contains text data in the dataset.
+         dataset_kwargs (`dict[str, Any]`, *optional*):
+             Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
+             `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True`
+             regardless of the provided value, since preprocessing is done on the fly.
+         dataset_num_proc (`int`, *optional*):
+             Number of processes to use for processing the dataset.
+         eos_token (`str`, *optional*):
+             Token used to indicate the end of a turn or sequence. If `None`, it defaults to
+             `processing_class.eos_token`.
+         pad_token (`str`, *optional*):
+             Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+             it falls back to `processing_class.eos_token`.
+         max_length (`int` or `None`, *optional*, defaults to `1024`):
+             Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
+             If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
+         packing (`bool`, *optional*, defaults to `False`):
+             Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
+             padding. Uses `max_length` to define sequence length.
+         packing_strategy (`str`, *optional*, defaults to `"bfd"`):
+             Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`.
+         padding_free (`bool`, *optional*, defaults to `False`):
+             Whether to perform forward passes without padding by flattening all sequences in the batch into a single
+             continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
+             supported with FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When
+             packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this
+             parameter.
+         pad_to_multiple_of (`int`, *optional*):
+             If set, the sequences will be padded to a multiple of this value.
+         eval_packing (`bool`, *optional*):
+             Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
+
+         > Parameters that control the training
+
+         completion_only_loss (`bool`, *optional*):
+             Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
+             only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
+             `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
+             loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full
+             sequence for [language modeling](#language-modeling) datasets.
+         assistant_only_loss (`bool`, *optional*, defaults to `False`):
+             Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only
+             on the assistant responses, which is supported only for [conversational](#conversational) datasets. If
+             `False`, loss is computed on the entire sequence.
+         loss_type (`str`, *optional*, defaults to `"nll"`):
+             Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic
+             Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)).
+         activation_offloading (`bool`, *optional*, defaults to `False`):
+             Whether to offload the activations to the CPU.
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks: Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     unsloth_logit_chunk_multiplier: Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Multiplier for chunked logit computations.'},
+     )
+     unsloth_grpo_mini_batch: Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+     )
+     max_seq_length: Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Maximum sequence length to truncate to.'},
+     )
+     def __init__(
+         self,
+         output_dir = None,
+         per_device_train_batch_size = 4,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         learning_rate = 5e-05,
+         lr_scheduler_type = 'linear',
+         lr_scheduler_kwargs = None,
+         warmup_steps = 0.1,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         optim_target_modules = None,
+         gradient_accumulation_steps = 2,
+         average_tokens_across_devices = True,
+         max_grad_norm = 1.0,
+         label_smoothing_factor = 0.0,
+         bf16 = False,
+         fp16 = False,
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         gradient_checkpointing = True,
+         gradient_checkpointing_kwargs = None,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         use_liger_kernel = False,
+         liger_kernel_config = None,
+         use_cache = False,
+         neftune_noise_alpha = None,
+         torch_empty_cache_steps = 250,
+         auto_find_batch_size = False,
+         logging_strategy = 'steps',
+         logging_steps = 1,
+         logging_first_step = False,
+         log_on_each_node = True,
+         logging_nan_inf_filter = False,
+         include_num_input_tokens_seen = False,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         disable_tqdm = None,
+         report_to = 'none',
+         run_name = None,
+         project = 'huggingface',
+         trackio_space_id = 'trackio',
+         eval_strategy = 'no',
+         eval_steps = None,
+         eval_delay = 0,
+         per_device_eval_batch_size = 4,
+         prediction_loss_only = False,
+         eval_on_start = False,
+         eval_do_concat_batches = True,
+         eval_use_gather_object = False,
+         eval_accumulation_steps = 2,
+         batch_eval_metrics = False,
+         save_only_model = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_on_each_node = False,
+         save_total_limit = None,
+         enable_jit_checkpoint = False,
+         push_to_hub = False,
+         hub_token = None,
+         hub_private_repo = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_always_push = False,
+         hub_revision = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         restore_callback_states_from_checkpoint = False,
+         full_determinism = False,
+         seed = 3407,
+         data_seed = 3407,
+         use_cpu = False,
+         accelerator_config = None,
+         parallelism_config = None,
+         dataloader_drop_last = False,
+         dataloader_num_workers = 0,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         dataloader_prefetch_factor = None,
+         remove_unused_columns = True,
+         label_names = None,
+         train_sampling_strategy = 'random',
+         length_column_name = 'length',
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         ddp_backend = None,
+         ddp_timeout = 1800,
+         fsdp = None,
+         fsdp_config = None,
+         deepspeed = None,
+         debug = '',
+         skip_memory_metrics = True,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         resume_from_checkpoint = None,
+         warmup_ratio = None,
+         logging_dir = None,
+         local_rank = -1,
+         model_init_kwargs = None,
+         chat_template_path = None,
+         dataset_text_field = 'text',
+         dataset_kwargs = None,
+         dataset_num_proc = None,
+         eos_token = None,
+         pad_token = None,
+         max_length = 1024,
+         packing = False,
+         packing_strategy = 'bfd',
+         padding_free = None,
+         pad_to_multiple_of = None,
+         eval_packing = None,
+         completion_only_loss = None,
+         assistant_only_loss = False,
+         loss_type = 'nll',
+         activation_offloading = False,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         unsloth_logit_chunk_multiplier = None,
+         unsloth_grpo_mini_batch = None,
+         max_seq_length = None,
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+         if num_train_epochs is None:
+             num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         import multiprocessing as _mp
+         if _mp.get_start_method() != 'fork':
+             dataset_num_proc = None
+         elif dataset_num_proc is None:
+             import psutil
+             dataset_num_proc = min(max((psutil.cpu_count() or 1) + 4, 2), 64)
+             memory_gb_left = psutil.virtual_memory().available / (1024**3)
+             if memory_gb_left <= 2: dataset_num_proc = 1
+             else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+         if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
+             from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
+             if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
+                 from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
+                 pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
+
+         super().__init__(
+             output_dir = output_dir,
+             per_device_train_batch_size = per_device_train_batch_size,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             learning_rate = learning_rate,
+             lr_scheduler_type = lr_scheduler_type,
+             lr_scheduler_kwargs = lr_scheduler_kwargs,
+             warmup_steps = warmup_steps,
+             optim = optim,
+             optim_args = optim_args,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             optim_target_modules = optim_target_modules,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             average_tokens_across_devices = average_tokens_across_devices,
+             max_grad_norm = max_grad_norm,
+             label_smoothing_factor = label_smoothing_factor,
+             bf16 = bf16,
+             fp16 = fp16,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             use_liger_kernel = use_liger_kernel,
+             liger_kernel_config = liger_kernel_config,
+             use_cache = use_cache,
+             neftune_noise_alpha = neftune_noise_alpha,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             auto_find_batch_size = auto_find_batch_size,
+             logging_strategy = logging_strategy,
+             logging_steps = logging_steps,
+             logging_first_step = logging_first_step,
+             log_on_each_node = log_on_each_node,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             disable_tqdm = disable_tqdm,
+             report_to = report_to,
+             run_name = run_name,
+             project = project,
+             trackio_space_id = trackio_space_id,
+             eval_strategy = eval_strategy,
+             eval_steps = eval_steps,
+             eval_delay = eval_delay,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             prediction_loss_only = prediction_loss_only,
+             eval_on_start = eval_on_start,
+             eval_do_concat_batches = eval_do_concat_batches,
+             eval_use_gather_object = eval_use_gather_object,
+             eval_accumulation_steps = eval_accumulation_steps,
+             batch_eval_metrics = batch_eval_metrics,
+             save_only_model = save_only_model,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_on_each_node = save_on_each_node,
+             save_total_limit = save_total_limit,
+             enable_jit_checkpoint = enable_jit_checkpoint,
+             push_to_hub = push_to_hub,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_always_push = hub_always_push,
+             hub_revision = hub_revision,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             full_determinism = full_determinism,
+             seed = seed,
+             data_seed = data_seed,
+             use_cpu = use_cpu,
+             accelerator_config = accelerator_config,
+             parallelism_config = parallelism_config,
+             dataloader_drop_last = dataloader_drop_last,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             train_sampling_strategy = train_sampling_strategy,
+             length_column_name = length_column_name,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             ddp_backend = ddp_backend,
+             ddp_timeout = ddp_timeout,
+             fsdp = fsdp,
+             fsdp_config = fsdp_config,
+             deepspeed = deepspeed,
+             debug = debug,
+             skip_memory_metrics = skip_memory_metrics,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             resume_from_checkpoint = resume_from_checkpoint,
+             warmup_ratio = warmup_ratio,
+             logging_dir = logging_dir,
+             local_rank = local_rank,
+             model_init_kwargs = model_init_kwargs,
+             chat_template_path = chat_template_path,
+             dataset_text_field = dataset_text_field,
+             dataset_kwargs = dataset_kwargs,
+             dataset_num_proc = dataset_num_proc,
+             eos_token = eos_token,
+             pad_token = pad_token,
+             max_length = max_length,
+             packing = packing,
+             packing_strategy = packing_strategy,
+             padding_free = padding_free,
+             pad_to_multiple_of = pad_to_multiple_of,
+             eval_packing = eval_packing,
+             completion_only_loss = completion_only_loss,
+             assistant_only_loss = assistant_only_loss,
+             loss_type = loss_type,
+             activation_offloading = activation_offloading, **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+         if unsloth_grpo_mini_batch is not None:
+             if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                 self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+             else:
+                 raise ValueError(
+                     "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                     "which is per_device_train_batch_size * gradient_accumulation_steps."
+                 )
+         self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+         self.max_seq_length = max_seq_length
+ pass
+
+ class _UnslothSFTTrainer(BaseTrainer):
+     """"""
+
+     _tag_names = ["trl", "sft"]
+     _name = "SFT"
+
+     def __init__(
+         self,
+         model: Union[str, PreTrainedModel],
+         args: Optional[Union[SFTConfig, TrainingArguments]] = None,
+         data_collator: Optional[DataCollator] = None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+         processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
+         compute_loss_func: Optional[Callable] = None,
+         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+         optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
+         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+         peft_config: Optional["PeftConfig"] = None,
+         formatting_func: Optional[Callable[[dict], str]] = None,
+     ):
+         # Args
+         if args is None:
+             model_name = model if isinstance(model, str) else model.config._name_or_path
+             model_name = model_name.split("/")[-1]
+             args = SFTConfig(f"{model_name}-SFT")
+         elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
+             dict_args = args.to_dict()
+             dict_args["hub_token"] = args.hub_token  # to_dict hides the hub_token
+             dict_args.pop("push_to_hub_token", None)
+             args = SFTConfig(**dict_args)
+
+         # Model
+         if isinstance(model, str):
+             model = create_model_from_path(model, **args.model_init_kwargs or {})
+         else:
+             if args.model_init_kwargs is not None:
+                 logger.warning(
+                     "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
+                     "The `model_init_kwargs` will be ignored."
+                 )
+         model_id = model.config._name_or_path
+
+         # Processing class
+         if processing_class is None:
+             processing_class = AutoProcessor.from_pretrained(model_id)
+
+         # Handle pad token for processors or tokenizers
+         if isinstance(processing_class, ProcessorMixin):
+             tokenizer = processing_class.tokenizer
+             self._is_vlm = True
+         elif isinstance(processing_class, PreTrainedTokenizerBase):
+             tokenizer = processing_class
+             self._is_vlm = False
+         else:
+             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
+
+         if args.eos_token is not None:
+             eos_token = args.eos_token
+             eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
+             if eos_token_id is None:
+                 raise ValueError(
+                     f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                     f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                     "in the vocabulary before using it as an EOS token."
+                 )
+             tokenizer.eos_token_id = eos_token_id
+
+         if args.chat_template_path is not None:
+             if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                 with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                     processing_class.chat_template = chat_template_file.read()
+                 added_tokens = []
+             else:
+                 model, processing_class, added_tokens = clone_chat_template(
+                     model, processing_class, args.chat_template_path
+                 )
+         else:
+             added_tokens = []
+
+         # Catch some wrong configurations related to VLMs
+         if self._is_vlm and args.packing:
+             raise ValueError(
+                 "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig."
+             )
+         if self._is_vlm and args.padding_free:
+             raise ValueError(
+                 "Padding-free training is not yet supported for vision-language models. Please set "
+                 "`padding_free=False` in the `SFTConfig`."
+             )
+         if self._is_vlm and args.assistant_only_loss:
+             raise ValueError(
+                 "Assistant-only loss is not yet supported for vision-language models. Please set "
+                 "`assistant_only_loss=False` in the `SFTConfig`."
+             )
+
+         # PEFT configuration and model wrapping
+         if False:
+             if added_tokens:
+                 # Ensure that the added tokens are trainable
+                 if peft_config.trainable_token_indices is None:
+                     peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
+                 elif "embed_tokens" not in peft_config.trainable_token_indices:
+                     peft_config.trainable_token_indices["embed_tokens"] = added_tokens
+                 else:
+                     peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
+
+                 # Ensure that the lm_head is trainable
+                 if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
+                     logger.warning(
+                         "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
+                         "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
+                         "tokens, leading to degraded generation quality. To fix this, add "
+                         "`modules_to_save=['lm_head']` to your PEFT configuration."
+                     )
+
+                     if peft_config.modules_to_save is None:
+                         peft_config.modules_to_save = ["lm_head"]
+                     else:
+                         peft_config.modules_to_save.append("lm_head")
+
+         # In Prompt Tuning a small set of trainable virtual tokens [continuous prompt embeddings] is prepended to the
+         # input. We store the number of these tokens so we can account for them correctly when calculating accuracy.
+         self.num_virtual_tokens = 0
+
+         if False:
+             pass
+             if model.active_adapter in model.peft_config:
+                 peft_model_config = model.peft_config[model.active_adapter]
+                 self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)
+
+         # Data collator
+         # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
+         # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
+         self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
+         use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
+         if self.padding_free:
+             if data_collator is not None:
+                 raise ValueError("Passing a custom data collator is not supported when using padding-free.")
+             if args.packing and args.packing_strategy == "wrapped":
+                 logger.warning(
+                     "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not "
+                     "recommended. Please refer to the documentation to understand why this is not recommended."
+                 )
+             if not use_flash_attention:
+                 logger.warning(
+                     "Padding-free training is enabled, but the attention implementation is not set to a supported "
+                     "flash attention variant. Padding-free training flattens batches into a single sequence, and only "
+                     "the following implementations are known to reliably support this: "
+                     f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
+                     "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
+                     "configuration to one of these supported options or verify that your attention mechanism can "
+                     "handle flattened sequences."
+                 )
+         # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format
+         # is prompt-completion, and False if the dataset format is language modeling.
+         dataset_sample = next(iter(train_dataset))
+         if args.completion_only_loss is None:
+             self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
+         else:
+             self.completion_only_loss = args.completion_only_loss
+
+         self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+         # Unsloth: override _is_vlm for VLM models that pass a bare tokenizer
+         if not self._is_vlm and self._is_vision_dataset:
+             _m = model
+             if hasattr(_m, "model"): _m = _m.model
+             if hasattr(getattr(_m, "config", None), "vision_config") or \
+                 _m.__class__.__name__.endswith("ForConditionalGeneration"):
+                 self._is_vlm = True
+         if self._is_vision_dataset and not self._is_vlm:
+             raise ValueError(
+                 "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                 "model does not seem to be a vision-language model. Please check your model and dataset."
+             )
+
901
+ if data_collator is None and not self._is_vision_dataset:
902
+ # Get the pad token: if not provided, use the one from the processing class or the eos token
903
+ # if the processing class does not have a pad token.
904
+ pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
905
+ pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
906
+ if pad_token_id is None:
907
+ raise ValueError(
908
+ f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
909
+ f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
910
+ "in the vocabulary before using it as a padding token."
911
+ )
912
+ data_collator = DataCollatorForLanguageModeling(
913
+ pad_token_id=pad_token_id,
914
+ completion_only_loss=self.completion_only_loss,
915
+ padding_free=self.padding_free,
916
+ pad_to_multiple_of=args.pad_to_multiple_of,
917
+ )
918
+ elif data_collator is None and self._is_vision_dataset:
919
+ data_collator = DataCollatorForVisionLanguageModeling(
920
+ processor=processing_class,
921
+ max_length=args.max_length,
922
+ completion_only_loss=self.completion_only_loss,
923
+ pad_to_multiple_of=args.pad_to_multiple_of,
924
+ dataset_text_field=args.dataset_text_field,
925
+ )
926
+
927
+ if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
928
+ logger.warning(
929
+ "You are using packing, but the attention implementation is not set to a supported flash attention "
930
+ "variant. Packing gathers multiple samples into a single sequence, and only the following "
931
+ f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. "
932
+ "Using other implementations may lead to cross-contamination between samples. To avoid this, either "
933
+ "disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration "
934
+ "to one of these supported options."
935
+ )
936
+ if args.assistant_only_loss and not is_conversational(dataset_sample):
937
+ raise ValueError(
938
+ "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only "
939
+ "supported for conversational datasets."
940
+ )
941
+
+        # Dataset
+        # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
+        # preprocessing [e.g., image-to-pixel conversion] is too costly and done on the fly instead.
+        skip_prepare_dataset = (
+            args.dataset_kwargs is not None
+            and args.dataset_kwargs.get("skip_prepare_dataset", False)
+            or self._is_vision_dataset
+        )
+        if not skip_prepare_dataset:
+            if self.completion_only_loss and formatting_func:
+                raise ValueError(
+                    "A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
+                    "Using a formatter converts the dataset to a language modeling type, conflicting with "
+                    "completion-only loss. To resolve this, apply your formatting function before passing the "
+                    "dataset, or disable `completion_only_loss` in `SFTConfig`."
+                )
+            self._unsloth_model_ref = model
+            train_dataset = self._prepare_dataset(
+                train_dataset, processing_class, args, args.packing, formatting_func, "train"
+            )
+            if eval_dataset is not None:
+                packing = args.packing if args.eval_packing is None else args.eval_packing
+                if isinstance(eval_dataset, dict):
+                    eval_dataset = {
+                        key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
+                        for key, dataset in eval_dataset.items()
+                    }
+                else:
+                    eval_dataset = self._prepare_dataset(
+                        eval_dataset, processing_class, args, packing, formatting_func, "eval"
+                    )
+
+        # Loss function
+        if args.loss_type == "nll":
+            pass  # use the default loss
+        elif args.loss_type == "dft":
+            if compute_loss_func is not None:
+                raise ValueError(
+                    "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
+                    "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
+                    "`compute_loss_func` is not allowed."
+                )
+            compute_loss_func = dft_loss
+        else:
+            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
+
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
+
+        # Initialize the Trainer. Parent class will handle:
+        # - DeepSpeed configuration [through create_accelerator_and_postprocess]
+        # - FSDP setup
+        # - Distributed training setup
+        # - Optimizer and scheduler creation
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            compute_loss_func=compute_loss_func,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Initialize activation offloading context
+        if self.args.activation_offloading:
+            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
+        else:
+            self.maybe_activation_offload_context = contextlib.nullcontext()
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class,
+        args,
+        packing: bool,
+        formatting_func: Optional[Callable[[dict], str]],
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # All Unsloth Zoo code licensed under LGPLv3
+        try:
+            if isinstance(dataset, ConstantLengthDataset): return dataset
+        except:
+            pass
+
+        map_kwargs = {}
+        use_desc = isinstance(dataset, Dataset)
+        is_vlm = hasattr(processing_class, "tokenizer")
+        tokenizer = processing_class
+        if is_vlm: tokenizer = processing_class.tokenizer
+
+        # Dynamic detection: check if model's module defines a function
+        # that requires token_type_ids when is_training=True
+        import sys as _sys
+        _needs_token_type_ids = False
+        # Split to avoid compiler substring match on masking_utils names
+        _ccm = 'create_' + 'causal_mask_mapping'
+        _model = getattr(self, '_unsloth_model_ref', None) or getattr(self, 'model', None)
+        if _model is not None:
+            for _m in (_model, getattr(_model, 'model', None)):
+                if _m is None: continue
+                _mod = _sys.modules.get(type(_m).__module__)
+                if _mod is not None and hasattr(_mod, _ccm):
+                    _needs_token_type_ids = True
+                    break
+
+        if not _needs_token_type_ids:
+            # Fallback: model not yet available, check processor class MRO
+            for _base in type(processing_class).__mro__:
+                _base_mod = getattr(_base, '__module__', '')
+                if 'transformers.models.' in _base_mod:
+                    _modeling_mod = _base_mod.replace('.processing_', '.modeling_')
+                    _mod = _sys.modules.get(_modeling_mod)
+                    if _mod is not None and hasattr(_mod, _ccm):
+                        _needs_token_type_ids = True
+                        break
+        if _needs_token_type_ids and hasattr(args, 'remove_unused_columns'):
+            args.remove_unused_columns = False
+
+        # Get max length
+        max_seq_length = getattr(args, "max_length", 0)
+        if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
+        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
+        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
+        if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
+        dataset_text_field = getattr(args, "dataset_text_field", "text")
+        do_truncation = max_seq_length != 0
+        do_formatting_func = False
+        do_tokenize = True
+
+        # Get correct column names
+        column_names = set(next(iter(dataset)).keys())
+        used_column_names = ["input_ids"]
+        if "attention_mask" in column_names:
+            used_column_names.append("attention_mask")
+        if _needs_token_type_ids:
+            used_column_names.append("token_type_ids")
+
+        # Check if already tokenized so skip
+        from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+        if "labels" in column_names:
+            # Most likely forgot data collator!
+            if is_vlm and not hasattr(tokenizer, "pad"):
+                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
+                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
+            self.data_collator = DataCollatorForSeq2Seq(tokenizer)
+            used_column_names.append("labels")
+            do_tokenize = False
+        elif "input_ids" in column_names:
+            # Skip dataset prep, and set data collator
+            if is_vlm and not hasattr(tokenizer, "pad"):
+                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
+                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
+            self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
+            do_tokenize = False
+        elif dataset_text_field not in column_names:
+            do_formatting_func = True
+            if formatting_func is None:
+                raise RuntimeError("Unsloth: You must specify a `formatting_func`")
+        pass
+
+        if do_tokenize:
+            # Check double BOS tokens
+            if do_formatting_func:
+                test_text = formatting_func(next(iter(dataset)))
+                if not isinstance(test_text, list):
+                    raise ValueError(
+                        "Unsloth: The `formatting_func` should return a list of processed strings."
+                    )
+                test_text = test_text[0]
+            else:
+                test_text = next(iter(dataset))[dataset_text_field][0]
+
+            # Get chat template
+            chat_template = getattr(processing_class, 'chat_template', '')
+            if chat_template == '' and is_vlm:
+                chat_template = getattr(tokenizer, 'chat_template', '')
+            if chat_template is None:
+                chat_template = ''
+
+            # Get bos_token
+            add_special_tokens = True
+            bos_token_1 = getattr(processing_class, 'bos_token', None)
+            bos_token_2 = getattr(tokenizer, 'bos_token', None)
+            bos_token = bos_token_1 or bos_token_2
+
+            if bos_token is not None:
+                if test_text.startswith(bos_token) or bos_token in chat_template:
+                    add_special_tokens = False
+                    print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
+            pass
+
+            # Create tokenize function
+            def _tokenize(example):
+                return tokenizer(
+                    example[dataset_text_field] if not do_formatting_func else formatting_func(example),
+                    truncation = do_truncation,
+                    max_length = max_seq_length,
+                    return_token_type_ids = _needs_token_type_ids,
+                    add_special_tokens = add_special_tokens,
+                )
+            pass
+
+            if not isinstance(dataset, IterableDataset):
+                import multiprocessing as _mp
+                if _mp.get_start_method() != 'fork':
+                    dataset_num_proc = None
+                else:
+                    dataset_num_proc = getattr(args, "dataset_num_proc", None)
+                    if dataset_num_proc is None:
+                        import psutil
+                        dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+                        memory_gb_left = psutil.virtual_memory().available / (1024**3)
+                        if memory_gb_left <= 2:
+                            dataset_num_proc = 1
+                        else:
+                            dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+                map_kwargs["num_proc"] = dataset_num_proc
+            else:
+                map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
+
+            if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
+            import warnings as _w
+            with _w.catch_warnings():
+                _w.filterwarnings("ignore", message=".*couldn't be hashed properly.*")
+                dataset = dataset.map(_tokenize, batched = True, remove_columns = list(column_names), **map_kwargs)
+
+            # If VLM, switch data collator since .pad is needed!
+            if is_vlm and not hasattr(processing_class, "pad"):
+                data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
+                self.data_collator = data_collator
+            pass
+        pass
+        if packing:
+            # Try using new packing which works in TRL
+            try:
+                pack_dataset
+            except:
+                print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
+                return dataset
+
+            if max_seq_length == 0:
+                raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+
+            if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
+            dataset = pack_dataset(
+                dataset.select_columns(used_column_names),
+                max_seq_length,
+                getattr(args, "packing_strategy", "bfd"),
+                map_kwargs,
+            )
+        pass
+        return dataset
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+        # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
+        # dataset. So we need to override the default signature columns to include "completion_mask" as well.
+        if self._signature_columns is None:
+            if self._is_vision_dataset:
+                self._signature_columns = ["messages", "prompt", "completion", "images", "input_ids", "labels", "attention_mask", "seq_lengths", "completion_mask", "assistant_masks"]
+            else:
+                self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
+
+    def compute_loss(
+        self, model, inputs, return_outputs = False, num_items_in_batch = None
+    ):
+        outputs = super().compute_loss(
+            model,
+            inputs,
+            return_outputs = return_outputs,
+            num_items_in_batch = num_items_in_batch,
+        )
+        return outputs
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
+
+    # Ensure the model card is saved along with the checkpoint
+    def _save_checkpoint(self, model, trial):
+        if self.args.hub_model_id is None:
+            model_name = Path(self.args.output_dir).name
+        else:
+            model_name = self.args.hub_model_id.split("/")[-1]
+        self.create_model_card(model_name=model_name)
+        super()._save_checkpoint(model, trial)
+class UnslothSFTTrainer(_UnslothSFTTrainer):
+    """
+
+    Trainer for the Supervised Fine-Tuning (SFT) method.
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+
+    Example:
+
+    ```python
+    from datasets import load_dataset
+    from trl import SFTTrainer
+
+    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
+
+    trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
+    trainer.train()
+    ```
+
+    Args:
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
+              config) with the keyword arguments in `args.model_init_kwargs`.
+            - A [`~transformers.PreTrainedModel`] object.
+
+            If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss
+            as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
+        args ([`SFTConfig`], *optional*):
+            Configuration for this trainer. If `None`, a default configuration is used.
+        data_collator ([`~transformers.DataCollator`], *optional*):
+            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
+            Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
+            and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
+            [prompt-completion](#prompt-completion) type. The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
+            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
+            with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
+            If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
+        compute_loss_func (`Callable`, *optional*):
+            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss
+            function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618)
+            used by [`Trainer`].
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a
+            [`~transformers.EvalPrediction`] and return a dictionary mapping metric names to metric values. When
+            passing [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
+            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
+            function needs to calculate and return the global summary statistics rather than accumulating the
+            batch-level statistics.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
+            `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
+            initializing the Trainer.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
+            made by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+        peft_config ([`~peft.PeftConfig`], *optional*):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+        formatting_func (`Callable`, *optional*):
+            Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
+            converts the dataset into a [language modeling](#language-modeling) type.
+
+    """
+    def __init__(
+        self,
+        model,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        compute_loss_func = None,
+        compute_metrics = None,
+        callbacks = None,
+        optimizer_cls_and_kwargs = None,
+        preprocess_logits_for_metrics = None,
+        peft_config = None,
+        formatting_func = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothSFTConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        if type(use_bf16) is not bool: use_bf16 = False
+        use_fp16 = getattr(args, 'fp16', False)
+        if type(use_fp16) is not bool: use_fp16 = False
+        force_float32 = False
+        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            # Forced float32 training
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            # Mixed precision training
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+            # args.mixed_precision is a new argument which needs to be set now
+        elif mixed_precision_dtype == 'bfloat16':
+            # Both False since bfloat16 full finetuning doesn't do any autocasting.
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+            # args.mixed_precision is a new argument which needs to be set now
+
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        if type(fp16_full_eval) is not bool: fp16_full_eval = False
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if type(bf16_full_eval) is not bool: bf16_full_eval = False
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if model is not None:
+            _warnings_issued = getattr(model, 'warnings_issued', None)
+            if _warnings_issued is None:
+                model.warnings_issued = {}
+            elif not isinstance(_warnings_issued, dict):
+                try:
+                    model.warnings_issued = dict(_warnings_issued)
+                except Exception:
+                    model.warnings_issued = {}
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+            elif args_max_seq_length is not None and model_max_seq_length is not None:
+                if args_max_seq_length > model_max_seq_length:
+                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+                    args.max_seq_length = model_max_seq_length
+        if 'max_length' not in locals() and not hasattr(args, 'max_length'):
+            pass
+        else:
+            if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
+                if hasattr(args, 'max_length'):
+                    args.max_length = args.max_seq_length
+                    max_length = args.max_length
+            else:
+                model_max_length = getattr(model, 'max_seq_length', None)
+                if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
+                if model_max_length is not None:
+                    args.max_length = model_max_length
+                    max_length = args.max_length
+                elif hasattr(args, 'max_length') and args.max_length is not None:
+                    max_length = args.max_length
+                    # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
+                    setattr(model, 'max_seq_length', max_length)
+                else:
+                    print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
+                    args.max_length = 1024
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = TransformersDataCollatorForLanguageModeling(
+                    __tokenizer,
+                    mlm = False,
+                    mlm_probability = 0.0,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(
+                    __tokenizer,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(
+                        __tokenizer.tokenizer,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+                else:
+                    data_collator = TransformersDataCollatorForLanguageModeling(
+                        __tokenizer.tokenizer,
+                        mlm = False,
+                        mlm_probability = 0.0,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('sft_trainer', other_metrics)
+        IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
+        from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
+        from unsloth_zoo.training_utils import fix_zero_training_loss
+        if 'tokenizer' not in locals(): tokenizer = processing_class
+        fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
+        fix_zero_training_loss(model, tokenizer, train_dataset)
+
+        # [TODO] Fix up DataParallel multiplying batch sizes
+        # [TODO] DDP works, but DP seems to not work? [TODO]
+        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+            if getattr(args, "_n_gpu", 1) != 1:
+                args._n_gpu = 1
+        if "model" in locals() and hasattr(model, "for_training"):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1538
+ super().__init__(
1539
+ model = model,
1540
+ args = args,
1541
+ data_collator = data_collator,
1542
+ train_dataset = train_dataset,
1543
+ eval_dataset = eval_dataset,
1544
+ processing_class = processing_class,
1545
+ compute_loss_func = compute_loss_func,
1546
+ compute_metrics = compute_metrics,
1547
+ callbacks = callbacks,
1548
+ optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
1549
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1550
+ peft_config = peft_config,
1551
+ formatting_func = formatting_func,**kwargs)
1552
+ if "model" in locals() and hasattr(model, "for_inference"):
1553
+ model.for_inference()
1554
+ if hasattr(self, 'neftune_hook_handle'):
1555
+ self.neftune_hook_handle.remove()
1556
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1557
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1558
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1559
+ pass
1560
+ if hasattr(self, 'accelerator'):
1561
+ scaler = self.accelerator.scaler
1562
+ current_model = model
1563
+ while hasattr(current_model, 'model'):
1564
+ current_model.accelerator_scaler = scaler
1565
+ current_model = current_model.model
1566
+ current_model.accelerator_scaler = scaler
1567
+ pass
1568
+ if hasattr(self, 'train'):
1569
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1570
+ pass
1571
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1572
+ _vllm_tok = self.llm.get_tokenizer()
1573
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1574
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1575
+ _vllm_tok.chat_template = _pc.chat_template
1576
+ pass
1577
+
1578
+ pass
1579
+
1580
+
1581
+ if hasattr(logger, "addFilter"):
1582
+ import logging
1583
+ class HideLoggingMessage(logging.Filter):
1584
+ def __init__(self, text): self.text = text
1585
+ def filter(self, x): return not (self.text in x.getMessage())
1586
+ pass
1587
+ logger.addFilter(HideLoggingMessage("`use_cache=True`"))
1588
+
unsloth_compiled_cache/moe_utils.py ADDED
@@ -0,0 +1,1320 @@
1
+ # Unsloth Zoo - Utilities for Unsloth
2
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
3
+ #
4
+ # This program is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU Affero General Public License as published
6
+ # by the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # This program is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU Affero General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU Affero General Public License
15
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import os
19
+ import shutil
20
+ import sys
21
+ import importlib.util
22
+ from typing import Optional, Tuple
23
+ from torch.autograd import Function
24
+
25
+ # Get compile location
26
+ UNSLOTH_COMPILE_LOCATION = os.environ.get(
27
+ "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache"
28
+ )
29
+
30
+
31
+ def _get_compile_location() -> str:
32
+ return os.path.abspath(
33
+ os.environ.get("UNSLOTH_COMPILE_LOCATION", UNSLOTH_COMPILE_LOCATION)
34
+ )
35
+
36
+
37
+ def _log_info(message: str):
38
+ if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1":
39
+ print(message)
40
+
41
+
42
+ def install_to_cache(source_path, destination_filename=None):
43
+ """
44
+ Copies a file to the unsloth_compiled_cache directory
45
+ to ensure it is available for compiled modules.
46
+ """
47
+ compile_location = _get_compile_location()
48
+ if not os.path.exists(compile_location):
49
+ try:
50
+ os.makedirs(compile_location)
51
+ except:
52
+ pass
53
+
54
+ current_file = os.path.abspath(source_path)
55
+ if destination_filename is None:
56
+ destination_filename = os.path.basename(current_file)
57
+
58
+ destination = os.path.abspath(os.path.join(compile_location, destination_filename))
59
+
60
+ # If source and dest are different, copy.
61
+ if current_file != destination:
62
+ try:
63
+ shutil.copy(current_file, destination)
64
+ except Exception:
65
+ pass
66
+
67
+
68
+ install_to_cache(__file__, "moe_utils.py")
69
+
70
+ _CACHED_FORWARD_MOE_BACKEND = None
71
+ _CACHED_MOE_UTILS_MODULE = None
72
+
73
+
74
+ def _load_cached_moe_utils_module():
75
+ global _CACHED_MOE_UTILS_MODULE
76
+
77
+ cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils.py"))
78
+ current_file = os.path.abspath(__file__)
79
+ if not os.path.isfile(cache_file) or cache_file == current_file:
80
+ return None
81
+
82
+ try:
83
+ module_name = "unsloth_cached_moe_utils"
84
+ module = sys.modules.get(module_name, None)
85
+ if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
86
+ _CACHED_MOE_UTILS_MODULE = module
87
+ return module
88
+
89
+ spec = importlib.util.spec_from_file_location(module_name, cache_file)
90
+ if spec is None or spec.loader is None:
91
+ return None
92
+ module = importlib.util.module_from_spec(spec)
93
+ sys.modules[module_name] = module
94
+ spec.loader.exec_module(module)
95
+ _CACHED_MOE_UTILS_MODULE = module
96
+ return module
97
+ except Exception:
98
+ return None
99
+
100
+
101
+ def get_forward_moe_backend():
102
+ """
103
+ Resolve forward_moe_backend from the compiled cache copy when available.
104
+ Falls back to the local module definition.
105
+ """
106
+ global _CACHED_FORWARD_MOE_BACKEND
107
+ module = _load_cached_moe_utils_module()
108
+ if module is not None and hasattr(module, "forward_moe_backend"):
109
+ _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend
110
+ return _CACHED_FORWARD_MOE_BACKEND
111
+
112
+ _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend
113
+ return _CACHED_FORWARD_MOE_BACKEND
114
+
115
+ # ============================================================================
116
+ # Grouped MM wrapper
117
+ # ============================================================================
118
+ # Simple wrapper around torch._grouped_mm that ensures contiguous inputs.
119
+ # Native backward works correctly - no custom autograd needed.
120
+ # ============================================================================
121
+
122
+
123
+ def _grouped_mm_with_backward_fix(
124
+ inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
125
+ ) -> torch.Tensor:
126
+ """
127
+ Grouped matmul with working backward pass.
128
+
129
+ Uses native torch._grouped_mm with contiguous inputs for correct gradients.
130
+ """
131
+ return torch._grouped_mm(inputs, weight, offs=offsets)
132
+
133
+
134
+ # Global flag to check if grouped GEMM is available
135
+ _GROUPED_GEMM_AVAILABLE = None
136
+ _TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")
137
+
138
+ # Check if GPU supports torch._grouped_mm (verified via runtime check)
139
+ _TORCH_GROUPED_MM_SUPPORTED = None
140
+
141
+
142
+ def _check_torch_grouped_mm_supported():
143
+ """
144
+ Check if torch._grouped_mm is actually supported on the current GPU.
145
+ We check for existence and verify with a dummy call.
146
+ A runtime probe is the only reliable check.
147
+ """
148
+ global _TORCH_GROUPED_MM_SUPPORTED
149
+ if _TORCH_GROUPED_MM_SUPPORTED is not None: return _TORCH_GROUPED_MM_SUPPORTED
150
+
151
+ if not _TORCH_GROUPED_MM_AVAILABLE:
152
+ _TORCH_GROUPED_MM_SUPPORTED = False
153
+ return False
154
+
155
+ if not torch.cuda.is_available():
156
+ _TORCH_GROUPED_MM_SUPPORTED = False
157
+ return False
158
+
159
+ try:
160
+ # Attempt a dummy grouped_mm call to verify support.
161
+ # This handles cases where the symbol exists but hardware is unsupported (e.g. < H100).
162
+ # It also allows support on newer hardware or backports without code changes.
163
+ device = torch.cuda.current_device()
164
+ dtype = torch.float16
165
+
166
+ # Minimal dummy data: 1 expert, 1 token, dim 8 (safe alignment)
167
+ x = torch.ones((1, 8), device=device, dtype=dtype)
168
+ w = torch.ones((1, 8, 8), device=device, dtype=dtype)
169
+ offs = torch.tensor([1], device=device, dtype=torch.int32)
170
+
171
+ torch._grouped_mm(x, w, offs=offs)
172
+ del x, w, offs
173
+ _TORCH_GROUPED_MM_SUPPORTED = True
174
+ except Exception:
175
+ _TORCH_GROUPED_MM_SUPPORTED = False
176
+
177
+ return _TORCH_GROUPED_MM_SUPPORTED
178
+
179
+
180
+ _TRITON_ALLOCATOR_INITIALIZED = False
181
+ _PERSISTENT_BUFFER = None
182
+
183
+
184
+ def _init_triton_allocator():
185
+ """
186
+ Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
187
+ This significantly reduces GPU utilization fluctuation.
188
+ """
189
+ global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
190
+ if _TRITON_ALLOCATOR_INITIALIZED: return
191
+
192
+ try:
193
+ import triton
194
+
195
+ # Create a persistent buffer that grows as needed
196
+ # This avoids allocating new memory on every kernel call
197
+
198
+ def persistent_alloc_fn(size: int, alignment: int, stream):
199
+ global _PERSISTENT_BUFFER
200
+ # Round up size to avoid frequent reallocations
201
+ # Round to nearest 128 bytes for alignment
202
+ rounded_size = ((size + 128 - 1) // 128) * 128
203
+
204
+ if (
205
+ _PERSISTENT_BUFFER is None
206
+ or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size()
207
+ < rounded_size
208
+ ):
209
+ # Allocate with small headroom (10%) to reduce reallocations
210
+ # Use ByteTensor (uint8) for raw byte storage
211
+ _PERSISTENT_BUFFER = torch.empty(
212
+ int(rounded_size * 1.1), device="cuda", dtype=torch.uint8
213
+ )
214
+ _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
215
+ return _PERSISTENT_BUFFER
216
+
217
+ triton.set_allocator(persistent_alloc_fn)
218
+ triton._unsloth_allocator_set = True
219
+ _TRITON_ALLOCATOR_INITIALIZED = True
220
+ except Exception:
221
+ pass
222
+
223
+
224
+ def _check_grouped_gemm_available():
225
+ """Check if Unsloth grouped GEMM kernels are available."""
226
+ if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False
227
+
228
+ global _GROUPED_GEMM_AVAILABLE
229
+ if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE
230
+
231
+ try:
232
+ from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma
233
+ _GROUPED_GEMM_AVAILABLE = True
234
+ _init_triton_allocator()
235
+ except (ImportError, ModuleNotFoundError):
236
+ _GROUPED_GEMM_AVAILABLE = False
237
+ return _GROUPED_GEMM_AVAILABLE
238
+
239
+
240
+ from functools import lru_cache
241
+
242
+
243
+ @lru_cache(maxsize=1)
244
+ def select_moe_backend():
245
+ """
246
+ Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability.
247
+ Choices: "grouped_mm", "unsloth_triton", "native_torch".
248
+ Default if unspecified: "grouped_mm".
249
+ """
250
+ # This Unsloth Zoo code section is licensed under AGPL3
251
+
252
+ requested = os.environ.get("UNSLOTH_MOE_BACKEND")
253
+ if requested:
254
+ if requested == "grouped_mm" and _check_torch_grouped_mm_supported():
255
+ return "grouped_mm"
256
+ if requested == "unsloth_triton" and _check_grouped_gemm_available():
257
+ return "unsloth_triton"
258
+ if requested == "native_torch":
259
+ return "native_torch"
260
+ _log_info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")
261
+
262
+ if _check_torch_grouped_mm_supported():
263
+ _log_info("Unsloth: Using MoE backend 'grouped_mm'")
264
+ return "grouped_mm"
265
+ if _check_grouped_gemm_available():
266
+ _log_info("Unsloth: Using MoE backend 'unsloth_triton'")
267
+ return "unsloth_triton"
268
+ return "native_torch"
269
+
270
+
271
+ def forward_moe_backend(
272
+ self,
273
+ hidden_states: torch.Tensor,
274
+ top_k_index: torch.Tensor,
275
+ top_k_weights: torch.Tensor,
276
+ ) -> torch.Tensor:
277
+ """
278
+ Dispatch MoE forward to the selected backend.
279
+ Centralizes backend selection to keep model-specific patches minimal.
280
+ """
281
+ # This Unsloth Zoo code section is licensed under AGPL3
282
+
283
+ backend = select_moe_backend()
284
+ if backend == "grouped_mm":
285
+ return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights)
286
+ if backend == "unsloth_triton":
287
+ return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights)
288
+ return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights)
289
+
290
+
291
+ @torch.no_grad()
292
+ def _get_routing_indices(selected_experts, num_experts):
293
+ """
294
+ Compute token→expert mapping for grouped GEMM.
295
+ Uses bincount instead of histc to avoid float conversion overhead.
296
+
297
+ Returns:
298
+ token_counts_by_expert: (num_experts,) token counts per expert
299
+ gather_indices: (total_tokens,) indices for gathering tokens in expert order
300
+ """
301
+ # This Unsloth Zoo code section is licensed under AGPL3
302
+
303
+ flat_experts = selected_experts.view(-1)
304
+
305
+ # bincount is faster than histc since it doesn't require float conversion
306
+ token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)
307
+
308
+ # argsort with stable=True preserves order within each expert
309
+ gather_indices = flat_experts.argsort(stable=True)
310
+
311
+ return token_counts_by_expert, gather_indices
312
+
313
+
314
+ def _silu_and_mul(x):
315
+ """Fused SiLU activation and element-wise multiply for gate/up projections."""
316
+ gate, up = x.chunk(2, dim=-1)
317
+ return F.silu(gate) * up
318
+
319
+
320
+ # ============================================================================
321
+ # Separated LoRA Helper Functions
322
+ # ============================================================================
323
+
324
+
325
+ def _has_lora_adapters(param) -> bool:
326
+ """Check if parameter has active LoRA adapters (PEFT ParamWrapper)."""
327
+ # Check if this is a PEFT LoRA wrapper
328
+ if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"):
329
+ return False
330
+ if hasattr(param, "disable_adapters") and param.disable_adapters:
331
+ return False
332
+ if hasattr(param, "merged") and param.merged:
333
+ return False
334
+ return len(param.lora_A) > 0
335
+
336
+
337
+ def _extract_lora_from_wrapper(
338
+ wrapper, adapter_name: str = "default", experts_module=None
339
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
340
+ """
341
+ Extract LoRA weights from PEFT ParamWrapper for MoE separated computation.
342
+
343
+ PEFT ParamWrapper for 3D parameters creates:
344
+ - lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim)
345
+ - lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R)
346
+
347
+ For grouped_mm: X @ first_weight @ second_weight
348
+
349
+ STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear
350
+ gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I)
351
+ down_proj: (E, H, I) - input X is (N, I), output is (N, H)
352
+
353
+ For gate_up with (E, 2*I, H):
354
+ lora_A: (E*R, H), lora_B: (2*I, E*R)
355
+ Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
356
+ first_weight from lora_A: (E*R, H) -> (E, H, R) after view/permute
357
+ second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) after view/permute
358
+
359
+ TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm
360
+ gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I)
361
+ down_proj: (E, I, H) - input X is (N, I), output is (N, H)
362
+
363
+ For gate_up with (E, H, 2*I):
364
+ lora_A: (E*R, H), lora_B: (2*I, E*R)
365
+ Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
366
+ first_weight from lora_A: (E*R, H) -> (E, H, R)
367
+ second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I)
368
+
369
+ Returns:
370
+ (first_weight, second_weight, scaling, num_experts) or None
371
+ """
372
+ # This Unsloth Zoo code section is licensed under AGPL3
373
+
374
+ try:
375
+ if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"):
376
+ return None
377
+
378
+ if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters:
379
+ return None
380
+ if hasattr(wrapper, "merged") and wrapper.merged:
381
+ return None
382
+
383
+ if not wrapper.lora_A:
384
+ return None
385
+
386
+ if adapter_name not in wrapper.lora_A:
387
+ adapter_name = list(wrapper.lora_A.keys())[0]
388
+
389
+ lora_A_module = wrapper.lora_A[adapter_name]
390
+ lora_B_module = wrapper.lora_B[adapter_name]
391
+
392
+ weight_A = lora_A_module.weight # (E*R, dim1)
393
+ weight_B = lora_B_module.weight # (dim2, E*R)
394
+ scaling = wrapper.scaling[adapter_name]
395
+ num_experts = getattr(wrapper, "num_experts", 1)
396
+
397
+ # GET EXPERTS MODULE TO CHECK FOR REGISTERED EXTRACTOR
398
+ if experts_module is None:
399
+ experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
400
+
401
+ # Check for model-specific LoRA extractor attached to the experts module
402
+ extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None)
403
+
404
+ if extractor_fn is not None:
405
+ return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)
406
+
407
+ # DEFAULT BEHAVIOR (Standard Format / Non-MoE)
408
+ if num_experts > 1:
409
+ total_rank = weight_A.shape[0]
410
+ rank_per_expert = total_rank // num_experts
411
+ dim1 = weight_A.shape[1]
412
+ dim2 = weight_B.shape[0]
413
+
414
+ # STANDARD FORMAT (Qwen3-MoE / GLM4):
415
+ # Base weights are (E, out_dim, in_dim) for F.linear.
416
+ # LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R).
417
+ # We need X @ (E, in_dim, R) @ (E, R, out_dim).
418
+
419
+ # first_weight: (E, in_dim, R) - from lora_A
420
+ # second_weight: (E, R, out_dim) - from lora_B
421
+ first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
422
+ first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R)
423
+
424
+ # second_weight (B): (E, R, out_dim)
425
+ second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
426
+ second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2)
427
+ else:
428
+ # Non-MoE case: return weights for X @ A.T @ B.T
429
+ first_weight = weight_A.T # (dim1, R)
430
+ second_weight = weight_B.T # (R, dim2)
431
+
432
+ return first_weight, second_weight, scaling, num_experts
433
+ except Exception:
434
+ return None
435
+
436
+
437
+ def _extract_lora_weights(
438
+ param, adapter_name: str = "default", num_experts: int = None, experts_module=None
439
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
440
+ """
441
+ Extract LoRA A and B weights from PEFT ParamWrapper.
442
+
443
+ This is a compatibility wrapper around _extract_lora_from_wrapper.
444
+ Use _extract_lora_from_wrapper directly for new code.
445
+
446
+ Returns:
447
+ (first_weight, second_weight, scaling) for (X @ first) @ second
448
+ """
449
+ # This Unsloth Zoo code section is licensed under AGPL3
450
+
451
+ # Set num_experts on param if provided, so _extract_lora_from_wrapper can use it
452
+ if num_experts is not None and not hasattr(param, "num_experts"):
453
+ param.num_experts = num_experts
454
+
455
+ result = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module)
456
+ if result is None:
457
+ return None
458
+ # Return first 3 elements (first_weight, second_weight, scaling) without num_experts
459
+ return result[0], result[1], result[2]
460
+
461
+
462
+ def _get_base_weight(param):
463
+ """Get base weight from potentially wrapped parameter or module."""
464
+ # This Unsloth Zoo code section is licensed under AGPL3
465
+
466
+ # Recursively unwrap PEFT layers
467
+ while hasattr(param, "base_layer"):
468
+ param = param.base_layer
469
+
470
+ if hasattr(param, "get_param"):
471
+ return param.get_param()
472
+
473
+ # Handle Modules (Linear, etc.)
474
+ if hasattr(param, "weight"):
475
+ return param.weight
476
+
477
+ return param
478
+
479
+
480
+ def _get_lora_wrapper_for_param(experts_module, param_name):
481
+ """
482
+ Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
483
+ Uses the explicit key stored in __dict__ if available.
484
+ Does NOT lazily setup wrappers as that requires traversing logic not present here.
485
+ """
486
+ # This Unsloth Zoo code section is licensed under AGPL3
487
+
488
+ if hasattr(experts_module, f"{param_name}_lora_wrapper"):
489
+ return getattr(experts_module, f"{param_name}_lora_wrapper")
490
+
491
+ # Check simple attributes if it's directly wrapped
492
+ if hasattr(experts_module, param_name):
493
+ attr = getattr(experts_module, param_name)
494
+ if hasattr(attr, "lora_A"): # Is a ParamWrapper
495
+ return attr
496
+
497
+ return None
498
+
499
+
500
+ def native_moe_grouped_mm(
501
+ inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
502
+ ) -> torch.Tensor:
503
+ """
504
+ Native implementation using grouped_mm with backward fix.
505
+
506
+ Uses custom autograd function to avoid PyTorch's grouped_mm backward stride bug.
507
+ """
508
+ return _grouped_mm_with_backward_fix(inputs, weight, offsets)
509
+
510
+
511
+ def _apply_lora_grouped_mm(
512
+ inputs: torch.Tensor,
513
+ lora_B: torch.Tensor,
514
+ lora_A: torch.Tensor,
515
+ offsets: torch.Tensor,
516
+ scaling: float,
517
+ grouped_mm_func=native_moe_grouped_mm,
518
+ ) -> torch.Tensor:
519
+ """
520
+ Apply LoRA using grouped GEMM: result = ((X @ B) @ A) * scaling
521
+
522
+ Args:
523
+ inputs: (total_tokens, in_dim)
524
+ lora_B: (num_experts, in_dim, rank) - First projection
525
+ lora_A: (num_experts, rank, out_dim) - Second projection
526
+ offsets: Grouped GEMM offsets
527
+ scaling: LoRA scaling factor
528
+ grouped_mm_func: Function to use for grouped GEMM (default: native_moe_grouped_mm)
529
+ """
530
+ # This Unsloth Zoo code section is licensed under AGPL3
531
+
532
+ # 1. First Matmul (X @ B)
533
+ # lora_B is (E, in_dim, R)
534
+ # Native needs (E, in_dim, R) -> No Transpose
535
+ lora_intermediate = grouped_mm_func(inputs, lora_B.contiguous(), offsets)
536
+
537
+ # 2. Second Matmul (result @ A)
538
+ # lora_A is (E, R, out_dim)
539
+ # Native needs (E, R, out_dim) -> No Transpose
540
+ lora_delta = grouped_mm_func(lora_intermediate, lora_A.contiguous(), offsets)
541
+
542
+ return lora_delta * scaling
543
+
544
+
545
+ def _should_use_separated_lora() -> bool:
546
+ """
547
+ Check if separated LoRA approach should be used (default: True).
548
+ Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead.
549
+ """
550
+ return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1"
551
+
552
+
553
+ # ============================================================================
554
+ # Model-specific Weight Preprocessing Hooks
555
+ # ============================================================================
556
+ # Each model can register its own preprocessing function for weight transposition.
557
+ # This allows the generic backend to work with different model weight layouts.
558
+
559
+ _WEIGHT_PREPROCESSORS = {}
560
+
561
+
562
+ def register_weight_preprocessor(model_type: str, preprocessor_fn):
563
+ """
564
+ Register a weight preprocessor for a specific model type.
565
+
566
+ Args:
567
+ model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe")
568
+ preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight
569
+ proj_type is "gate_up" or "down"
570
+ """
571
+ _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn
572
+
573
+
574
+ def get_weight_preprocessor(model_type: str):
575
+ """Get registered weight preprocessor for model type."""
576
+ return _WEIGHT_PREPROCESSORS.get(model_type)
577
+
578
+
579
+ def preprocess_weight(
580
+ weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
581
+ ):
582
+ """
583
+ Preprocess weight tensor for grouped_mm compatibility.
584
+
585
+ Uses model-specific preprocessor if registered, otherwise uses default logic.
586
+
587
+ Args:
588
+ weight: Weight tensor (E, dim1, dim2) or similar
589
+ proj_type: "gate_up" or "down"
590
+ hidden_dim: Hidden dimension for shape inference
591
+ model_type: Optional model type to use specific preprocessor
592
+
593
+ Returns:
594
+ Weight tensor in (E, in_dim, out_dim) format for grouped_mm
595
+ """
596
+ # This Unsloth Zoo code section is licensed under AGPL3
597
+
598
+ if model_type and model_type in _WEIGHT_PREPROCESSORS:
599
+ return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim)
600
+
601
+ # Default preprocessing: check if transposition is needed
602
+ if proj_type == "gate_up":
603
+ # For gate_up, we need (E, hidden_dim, 2*intermediate)
604
+ if weight.shape[1] == hidden_dim:
605
+ return weight
606
+ else:
607
+ return weight.transpose(-2, -1)
608
+ else: # down
609
+ # For down, we need (E, intermediate, hidden_dim)
610
+ if weight.shape[2] == hidden_dim:
611
+ return weight
612
+ else:
613
+ return weight.transpose(-2, -1)
614
+
615
+
616
+ # ============================================================================
617
+ # Generic MoE Detection and ParamWrapper Patching
618
+ # ============================================================================
619
+
620
+
621
+ def _is_moe_experts_module(module) -> bool:
622
+ """
623
+ Check if module is an MoE experts layer (generic, not model-specific).
624
+
625
+ Detects modules with stacked expert weights as 3D nn.Parameter:
626
+ - gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.)
627
+ - w1/w2/w3 pattern (older MoE models)
628
+ """
629
+ # This Unsloth Zoo code section is licensed under AGPL3
630
+
631
+ import torch.nn as nn
632
+
633
+ # Check for gate_up_proj pattern
634
+ # After PEFT's nn.utils.parametrize wrapping, accessing gate_up_proj
635
+ # returns torch.Tensor (not nn.Parameter), so we must accept both.
636
+ if hasattr(module, "gate_up_proj"):
637
+ param = module.gate_up_proj
638
+ # 4-bit parameters are packed into 2D tensors (n_params, 1) or similar.
639
+ # Standard MoE weights are 3D (num_experts, in, out).
640
+ if isinstance(param, (nn.Parameter, torch.Tensor)) and param.ndim in (2, 3):
641
+ return True
642
+
643
+ # Check for w1/w2 pattern (separate gate/up projections)
644
+ if hasattr(module, "w1") and hasattr(module, "w2"):
645
+ w1 = module.w1
646
+ if isinstance(w1, (nn.Parameter, torch.Tensor)) and w1.ndim in (2, 3):
647
+ return True
648
+
649
+ return False
650
+
651
+
652
+ # Aliases for compatibility with gpt_oss.py
653
+ _get_moe_lora_weights = _extract_lora_from_wrapper
654
+
655
+
656
+ # Store original ParamWrapper.forward for fallback
657
+ _original_param_wrapper_forward = None
658
+
+
+def _patched_param_wrapper_forward(
+    self, x: torch.Tensor, *args, **kwargs
+) -> torch.Tensor:
+    """
+    Patched ParamWrapper.forward for MoE separated LoRA.
+
+    For MoE expert modules:
+    - Bypasses PEFT's _activate_lora parametrization context
+    - Stores LoRA data by parameter_name for forward_native_grouped_mm to use
+
+    For non-MoE modules:
+    - Falls back to the original PEFT forward
+    """
+    # This Unsloth Zoo code section is licensed under AGPL3
+
+    # CRITICAL: Use self.base_layer for the forward call (immediate parent),
+    # NOT self.get_base_layer(), which recursively traverses to the deepest layer!
+    # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts
+    immediate_base_layer = self.base_layer
+
+    # For storing LoRA data, we DO need the actual experts module.
+    # Use get_base_layer() to find it (recursive traversal is correct here).
+    experts_module = self.get_base_layer()
+
+    use_separated = _should_use_separated_lora()
+    param_name = getattr(self, "parameter_name", None)
+
+    # Check if this is an MoE experts module that should use separated LoRA
+    if (
+        use_separated
+        and param_name in ("gate_up_proj", "down_proj")
+        and _is_moe_experts_module(experts_module)
+    ):
+        # MoE experts: bypass PEFT's _activate_lora, use separated computation
+
+        # Check adapter state
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            return immediate_base_layer(x, *args, **kwargs)
+
+        if self.merged:
+            return immediate_base_layer(x, *args, **kwargs)
+
+        # Ensure wrapper.num_experts is set for LoRA weight reshaping
+        if not hasattr(self, "num_experts"):
+            if hasattr(experts_module, "num_experts"):
+                self.num_experts = experts_module.num_experts
+            elif hasattr(experts_module, param_name):
+                p = getattr(experts_module, param_name)
+                if hasattr(p, "shape") and len(p.shape) >= 1:
+                    self.num_experts = p.shape[0]
+
+        # Extract LoRA for this specific parameter
+        lora_data = _extract_lora_from_wrapper(self)
+
+        if lora_data is not None and param_name:
+            # Store LoRA data on the EXPERTS MODULE (not base_layer),
+            # e.g. _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
+            lora_attr = f"_unsloth_lora_{param_name}"
+            setattr(experts_module, lora_attr, lora_data)
+
+        try:
+            # Call the IMMEDIATE base_layer to preserve the wrapper chain
+            # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts)
+            result = immediate_base_layer(x, *args, **kwargs)
+        finally:
+            # Clean up
+            if param_name:
+                lora_attr = f"_unsloth_lora_{param_name}"
+                if hasattr(experts_module, lora_attr):
+                    delattr(experts_module, lora_attr)
+
+        return result
+
+    # Non-MoE: use the original PEFT forward with _activate_lora
+    return _original_param_wrapper_forward(self, x, *args, **kwargs)
+
+
+def patch_param_wrapper_for_moe():
+    """
+    Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE.
+
+    This should be called after PEFT is imported.
+    """
+    # This Unsloth Zoo code section is licensed under AGPL3
+
+    global _original_param_wrapper_forward
+
+    module = _load_cached_moe_utils_module()
+    if module is not None and hasattr(module, "patch_param_wrapper_for_moe"):
+        try:
+            return module.patch_param_wrapper_for_moe()
+        except Exception:
+            pass
+
+    try:
+        from peft.tuners.lora.layer import ParamWrapper
+
+        # Store the original forward exactly once
+        if _original_param_wrapper_forward is None:
+            _original_param_wrapper_forward = ParamWrapper.forward
+
+        # Patch with our version
+        ParamWrapper.forward = _patched_param_wrapper_forward
+
+        return True
+    except ImportError:
+        return False
+
+
+def forward_native_grouped_mm(
+    self,
+    hidden_states: torch.Tensor,
+    top_k_index: torch.Tensor,
+    top_k_weights: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Native PyTorch grouped GEMM MoE forward pass.
+    Uses torch._grouped_mm, which is significantly faster than a per-expert loop and works without Triton dependencies.
+    Requires torch._grouped_mm support (verified via runtime check).
+    """
+    # This Unsloth Zoo code section is licensed under AGPL3
+
+    # Runtime safety check - defense in depth
+    if not _check_torch_grouped_mm_supported():
+        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+        raise RuntimeError(
+            f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). "
+            f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend."
+        )
+
+    is_2d_input = hidden_states.dim() == 2
+    if is_2d_input:
+        sequence_length, hidden_dim = hidden_states.shape
+        batch_size = 1
+    else:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+
+    hidden_states = hidden_states.view(-1, hidden_dim)
+
+    # 1. Calculate routing
+    flat_top_k = top_k_index.view(-1)
+    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()
+
+    # 2. Sort indices to group tokens by expert
+    sorted_indices = torch.argsort(flat_top_k, stable=True)
+    token_indices = sorted_indices // top_k_index.shape[-1]
+
+    # 3. Permute input.
+    # Since top_k may have expanded the tokens, use token_indices to map back to the original input.
+    permuted_input = hidden_states[token_indices]
+
+    # 4. Prepare grouped MM arguments
+    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+
+    # ========================================================================
+    # Gate + Up projection with optional separated LoRA (DEFAULT)
+    # ========================================================================
+    use_separated_lora = _should_use_separated_lora()
+    gate_up_lora = None
+
+    # Check for injected LoRA data from the patched ParamWrapper (preferred path)
+    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
+        gate_up_lora = self._unsloth_lora_gate_up_proj[:3]  # (first_weight, second_weight, scaling)
+    # Fallback: check the parameter directly (for older wrapping patterns)
+    elif (
+        use_separated_lora
+        and hasattr(self, "gate_up_proj")
+        and _has_lora_adapters(self.gate_up_proj)
+    ):
+        gate_up_lora = _extract_lora_weights(
+            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
+        )
+
+    if hasattr(self, "gate_up_proj"):
+        # Get base weights (raw, without LoRA)
+        gate_up_base = _get_base_weight(self.gate_up_proj)
+
+        # Get model type for preprocessing (if registered)
+        model_type = getattr(self, "_unsloth_model_type", None)
+
+        # Handle different weight shapes using the preprocessor.
+        # torch._grouped_mm backward requires contiguous weights; preprocessing may return a transposed view.
+        w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type)
+        # Base forward: X @ W
+        mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
+
+        # Add the separated LoRA contribution: + ((X @ first) @ second) * scaling.
+        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling).
+        if gate_up_lora is not None:
+            first_weight, second_weight, scaling = gate_up_lora
+
+            # Cast to input dtype (LoRA weights are float32, input may be bfloat16).
+            # Ensure contiguous tensors for grouped_mm alignment requirements.
+            first_weight = first_weight.to(permuted_input.dtype).contiguous()
+            second_weight = second_weight.to(permuted_input.dtype).contiguous()
+
+            # Step 1: permuted_input @ first_weight
+            lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)
+            lora_out = lora_out.contiguous()
+
+            # Step 2: result @ second_weight.
+            # Handle an unaligned output dimension or other grouped_mm failures.
+            try:
+                if second_weight.shape[-1] % 8 != 0:
+                    pad_size = 8 - (second_weight.shape[-1] % 8)
+                    second_weight_padded = F.pad(second_weight, (0, pad_size)).contiguous()
+                    lora_delta = _grouped_mm_with_backward_fix(
+                        lora_out, second_weight_padded, offsets
+                    )
+                    lora_delta = lora_delta[:, :-pad_size]
+                else:
+                    lora_delta = _grouped_mm_with_backward_fix(
+                        lora_out, second_weight, offsets
+                    )
+            except RuntimeError:
+                # Fall back to a manual loop if grouped_mm fails (e.g. stride alignment)
+                lora_delta = torch.empty(
+                    (lora_out.shape[0], second_weight.shape[-1]),
+                    dtype=lora_out.dtype,
+                    device=lora_out.device,
+                )
+                cpu_offsets = offsets.cpu().tolist()
+                prev_offset = 0
+                for i, end in enumerate(cpu_offsets):
+                    if prev_offset < end:
+                        lora_delta[prev_offset:end] = torch.matmul(
+                            lora_out[prev_offset:end], second_weight[i]
+                        )
+                    prev_offset = end
+
+            # Add the scaled LoRA contribution
+            mm1_out = mm1_out + lora_delta * scaling
+
+        if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
+            num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device)
+            bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0)
+            mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)
+
+        if "GptOssExperts" in self.__class__.__name__:
+            gate = mm1_out[..., ::2]
+            up = mm1_out[..., 1::2]
+        else:
+            gate, up = mm1_out.chunk(2, dim=-1)
+
+    elif hasattr(self, "w1") and hasattr(self, "w3"):
+        # Separate w1/w3 weights (older models)
+        w1_base = _get_base_weight(self.w1)
+        w3_base = _get_base_weight(self.w3)
+
+        w1 = w1_base.transpose(-2, -1)
+        w3 = w3_base.transpose(-2, -1)
+
+        gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
+        up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets)
+
+        # Add LoRA for w1 and w3 separately if present
+        if use_separated_lora:
+            if _has_lora_adapters(self.w1):
+                w1_lora = _extract_lora_weights(self.w1, experts_module=self)
+                if w1_lora is not None:
+                    lora_A, lora_B, scaling = w1_lora
+                    lora_A_t = lora_A.transpose(-2, -1)
+                    lora_A_out = _grouped_mm_with_backward_fix(
+                        permuted_input, lora_A_t, offsets
+                    )
+                    lora_B_t = lora_B.transpose(-2, -1)
+                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
+                    gate = gate + lora_B_out * scaling
+
+            if _has_lora_adapters(self.w3):
+                w3_lora = _extract_lora_weights(self.w3, experts_module=self)
+                if w3_lora is not None:
+                    lora_A, lora_B, scaling = w3_lora
+                    lora_A_t = lora_A.transpose(-2, -1)
+                    lora_A_out = _grouped_mm_with_backward_fix(
+                        permuted_input, lora_A_t, offsets
+                    )
+                    lora_B_t = lora_B.transpose(-2, -1)
+                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
+                    up = up + lora_B_out * scaling
+    else:
+        raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.")
+
+    # Activation
+    if "GptOssExperts" in self.__class__.__name__:
+        # Custom activation from GptOss
+        limit = getattr(self, "limit", 7.0)
+        alpha = getattr(self, "alpha", 1.702)
+
+        gate = gate.clamp(min=None, max=limit)
+        up = up.clamp(min=-limit, max=limit)
+        glu = gate * torch.sigmoid(gate * alpha)
+        inter = (up + 1.0) * glu
+    else:
+        inter = F.silu(gate) * up
+
+    # ========================================================================
+    # Down projection with optional separated LoRA (DEFAULT)
+    # ========================================================================
+    down_lora = None
+
+    # Check for injected LoRA data from the patched ParamWrapper (preferred path)
+    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
+        down_lora = self._unsloth_lora_down_proj[:3]  # (first_weight, second_weight, scaling)
+    # Fallback: check the parameter directly (for older wrapping patterns)
+    elif (
+        use_separated_lora
+        and hasattr(self, "down_proj")
+        and _has_lora_adapters(self.down_proj)
+    ):
+        down_lora = _extract_lora_weights(
+            self.down_proj, num_experts=self.num_experts, experts_module=self
+        )
+
+    if hasattr(self, "down_proj"):
+        # Get base weights
+        down_base = _get_base_weight(self.down_proj)
+
+        # Get model type for preprocessing (if registered)
+        model_type = getattr(self, "_unsloth_model_type", None)
+
+        # Handle different weight shapes using the preprocessor
+        w2 = preprocess_weight(down_base, "down", hidden_dim, model_type)
+
+        # Base forward
+        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
+
+        # Add the separated LoRA contribution if present.
+        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling).
+        if down_lora is not None:
+            first_weight, second_weight, scaling = down_lora
+
+            # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
+            first_weight = first_weight.to(inter.dtype).contiguous()
+            second_weight = second_weight.to(inter.dtype).contiguous()
+
+            # Step 1: inter @ first_weight
+            lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets)
+            lora_out = lora_out.contiguous()
+
+            # Step 2: result @ second_weight
+            try:
+                lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
+            except RuntimeError:
+                # Fall back to a manual loop
+                lora_delta = torch.empty(
+                    (lora_out.shape[0], second_weight.shape[-1]),
+                    dtype=lora_out.dtype,
+                    device=lora_out.device,
+                )
+                cpu_offsets = offsets.cpu().tolist()
+                prev_offset = 0
+                for i, end in enumerate(cpu_offsets):
+                    if prev_offset < end:
+                        lora_delta[prev_offset:end] = torch.matmul(
+                            lora_out[prev_offset:end], second_weight[i]
+                        )
+                    prev_offset = end
+
+            # Add the scaled LoRA contribution
+            mm2_out = mm2_out + lora_delta * scaling
+
+        if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
+            bias_expanded = self.down_proj_bias.repeat_interleave(
+                num_tokens_per_expert.to(self.down_proj_bias.device), dim=0
+            ).to(mm2_out.device)
+            mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)
+
+    elif hasattr(self, "w2"):
+        w2_base = _get_base_weight(self.w2)
+        w2 = w2_base.transpose(-2, -1)
+
+        # Base forward
+        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
+
+        # Add LoRA if present
+        if use_separated_lora and _has_lora_adapters(self.w2):
+            w2_lora = _extract_lora_weights(self.w2, experts_module=self)
+            if w2_lora is not None:
+                lora_A, lora_B, scaling = w2_lora
+                lora_A_t = lora_A.transpose(-2, -1).contiguous()
+                lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets)
+                lora_B_t = lora_B.transpose(-2, -1).contiguous()
+                lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
+                mm2_out = mm2_out + lora_B_out * scaling
+    else:
+        raise AttributeError("MoE layer must have 'down_proj' or 'w2'.")
+
+    # 5. Apply routing weights and scatter-add (reduce)
+    flat_weights = top_k_weights.view(-1)
+    permuted_weights = flat_weights[sorted_indices]
+    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+
+    final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype))
+
+    if is_2d_input:
+        return final_hidden_states
+
+    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+
+
+def forward_triton_grouped_gemm(
+    self,
+    hidden_states: torch.Tensor,
+    top_k_index: torch.Tensor,
+    top_k_weights: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Grouped GEMM MoE forward pass using Triton kernels.
+    Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).
+    """
+    # This Unsloth Zoo code section is licensed under AGPL3
+
+    # Import grouped GEMM interface
+    from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
+
+    # Import autotune cache
+    from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels
+
+    # TMA support can be checked directly via _supports_tma() when needed.
+    # Kernel configs are cached on the module itself as self._unsloth_moe_configs.
+    if not hasattr(self, "_unsloth_moe_configs"):
+        self._unsloth_moe_configs = None
+
+    use_separated_lora = _should_use_separated_lora()
+
+    # Handle 3D inputs (batch_size, seq_len, hidden_dim)
+    is_3d = hidden_states.dim() == 3
+    if is_3d:
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        num_tokens = batch_size * seq_len
+        # Also flatten top_k inputs if they are 3D
+        if top_k_index.dim() == 3:
+            top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
+        if top_k_weights.dim() == 3:
+            top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
+    else:
+        num_tokens, hidden_dim = hidden_states.shape
+
+    top_k = top_k_index.shape[1]
+
+    # Cache model dimensions and kernel configs on the first call
+    if self._unsloth_moe_configs is None:
+        intermediate_dim = self.gate_up_proj.shape[1] // 2
+
+        # Autotune the first GEMM
+        gemm1_configs = get_or_autotune_moe_kernels(
+            num_experts=self.num_experts,
+            hidden_dim=hidden_dim,
+            intermediate_dim=intermediate_dim * 2,
+            top_k=top_k,
+            dtype=hidden_states.dtype,
+        )
+
+        # Autotune the second GEMM
+        gemm2_configs = get_or_autotune_moe_kernels(
+            num_experts=self.num_experts,
+            hidden_dim=intermediate_dim,
+            intermediate_dim=hidden_dim,  # Output dim for the 2nd GEMM is hidden_dim
+            top_k=top_k,
+            dtype=hidden_states.dtype,
+        )
+
+        self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)
+
+        # Clear autotuning memory overhead
+        torch.cuda.empty_cache()
+
+    # Unpack cached configs
+    intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs
+
+    # Unpack specific kernel configs
+    fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
+    fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs
+
+    # Compute routing indices for grouped GEMM
+    token_counts_by_expert, gather_indices = _get_routing_indices(
+        top_k_index, self.num_experts
+    )
+    offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)
+
+    if self.gate_up_proj.shape[-1] == hidden_dim:
+        w1 = self.gate_up_proj
+    else:
+        w1 = self.gate_up_proj.transpose(-2, -1).contiguous()
+
+    # First grouped GEMM: gate_up projection
+    first_gemm_output = grouped_gemm(
+        X=hidden_states,
+        W=w1,
+        m_sizes=token_counts_by_expert,
+        topk=top_k,
+        gather_indices=gather_indices,
+        permute_x=True,
+        permute_y=False,
+        autotune=False,  # We use cached configs
+        kernel_config_fwd=fwd_config_1,
+        kernel_config_bwd_dX=bwd_dX_config_1,
+        kernel_config_bwd_dW=bwd_dW_config_1,
+        is_first_gemm=True,
+    )
+
+    # Apply SiLU activation and multiply gate with up
+    intermediate = _silu_and_mul(first_gemm_output)
+
+    # Second grouped GEMM: down projection.
+    # Prepare LoRA data first.
+    down_lora = None
+    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
+        down_lora = self._unsloth_lora_down_proj[:3]
+    elif (
+        use_separated_lora
+        and hasattr(self, "down_proj")
+        and _has_lora_adapters(self.down_proj)
+    ):
+        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts)
+
+    if self.down_proj.shape[-1] == intermediate.shape[-1]:
+        w2 = self.down_proj
+    else:
+        w2 = self.down_proj.transpose(-2, -1).contiguous()
+
+    second_gemm_output = grouped_gemm(
+        X=intermediate,
+        W=w2,
+        m_sizes=token_counts_by_expert,
+        topk=top_k,
+        gather_indices=gather_indices,
+        permute_x=False,
+        permute_y=True,
+        autotune=False,  # We use cached configs
+        kernel_config_fwd=fwd_config_2,
+        kernel_config_bwd_dX=bwd_dX_config_2,
+        kernel_config_bwd_dW=bwd_dW_config_2,
+        is_first_gemm=False,
+    )
+
+    # Add the separated LoRA contribution for the down projection
+    if down_lora is not None:
+        first_weight, second_weight, scaling = down_lora
+
+        # `intermediate` is already permuted from step 1, so the same offsets apply.
+        first_weight = first_weight.to(intermediate.dtype)
+        second_weight = second_weight.to(intermediate.dtype)
+
+        lora_delta = _apply_lora_grouped_mm(
+            intermediate,
+            first_weight,
+            second_weight,
+            offsets,
+            scaling,
+            grouped_mm_func=native_moe_grouped_mm,
+        )
+
+        second_gemm_output = second_gemm_output + lora_delta
+
+    # Apply routing weights and sum across top_k experts.
+    # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim).
+    # Ensure top_k_weights matches dtype (it can be float32 from softmax).
+    top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
+    final_hidden_states = (
+        second_gemm_output.view(num_tokens, top_k, hidden_dim)
+        * top_k_weights_casted[..., None]
+    )
+    final_hidden_states = final_hidden_states.sum(dim=1)
+
+    if is_3d:
+        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
+
+    return final_hidden_states
+
+
+@torch.compiler.disable
+def forward_native_moe_loop(
+    self,
+    hidden_states: torch.Tensor,
+    top_k_index: torch.Tensor,
+    top_k_weights: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Loop-based MoE forward pass. Loops over experts that have tokens routed to them.
+    Explicitly disabled for torch.compile to prevent graph breaks / recompilation issues with dynamic control flow.
+    """
+    # This Unsloth Zoo code section is licensed under AGPL3
+    final_hidden_states = torch.zeros_like(hidden_states)
+
+    # Create expert mask and find which experts have tokens
+    with torch.no_grad():
+        expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
+        expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, n_tokens)
+        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+    # Only loop over experts that actually have tokens routed to them
+    for expert_idx_t in expert_hit:
+        expert_idx = expert_idx_t.item()
+
+        # Find which tokens are routed to this expert
+        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+
+        # Gather only the tokens for this expert
+        current_state = hidden_states[token_idx]
+
+        # Compute the gate_up projection for this expert only.
+        # Handle 'gate_up_proj' or 'w1'/'w3'.
+        if hasattr(self, "gate_up_proj"):
+            gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
+                2, dim=-1
+            )
+        else:
+            gate = F.linear(current_state, self.w1[expert_idx])
+            up = F.linear(current_state, self.w3[expert_idx])
+
+        current_hidden_states = self.act_fn(gate) * up
+
+        # Compute the down projection for this expert only
+        if hasattr(self, "down_proj"):
+            current_hidden_states = F.linear(
+                current_hidden_states, self.down_proj[expert_idx]
+            )
+        else:
+            current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])
+
+        # Apply routing weights
+        current_hidden_states = (
+            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+        )
+
+        # Scatter back to the final output
+        final_hidden_states.index_add_(
+            0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
+        )
+
+    return final_hidden_states
uv.lock CHANGED
The diff for this file is too large to render. See raw diff