Ev3Dev committed
Commit 2bf5069 · verified · 1 Parent(s): ad39f2a

Upload folder using huggingface_hub

_dashboard_state.json CHANGED
@@ -1,29 +1,29 @@
  {
- "timestamp": 1772986999.4539518,
+ "timestamp": 1772994420.314485,
  "step": 20,
  "episode_done": false,
- "cumulative_reward": -1.2923024829860532,
- "gen_time_s": 17.14,
- "model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"candidate_target_genes\": [],\n\"suggested_wet_lab_targets\": []\n},\n\"Justification\": \"The previous attempt failed due to incomplete genomic profiling required for robust marker identification via CRISPR/Cas9 editing followed by functional readout. Without identifying specific high-confidence target genes within the identified clusters through comprehensive transcriptomic characterization first, designing a focused knockdown screen lacks the necessary molecular basis.\",\n\"Confidence\": 0.85\n}",
+ "cumulative_reward": 1.8024931955050638,
+ "gen_time_s": 23.52,
+ "model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"disease_comparison_groups\": true,\n\"treatment_condition_focus\": false,\n\"criterion_for_target_genes\": \"Differentially Expressed Between Healthy vs AD Brain Cells\"\n},\n\"Justification\": \"The primary goal remains identifying DGEs across conditions rather than targeting specific known pathways first. While 'filter_data' was previously performed but failed due to lack of validated markers from previous attempts, re-running filtering based solely on technical/quality criteria would still leave ambiguous candidates if noise pe",
  "model_thinking": "",
  "thinking_enabled": false,
  "task": {
- "problem_statement": "Validate candidate biomarker ERBB2 for breast cancer in breast tissue using single-cell RNA sequencing.",
+ "problem_statement": "Identify differentially expressed genes between Alzheimer's disease and healthy brain tissue using single-cell RNA sequencing.",
  "organism": "human",
- "tissue": "breast",
+ "tissue": "brain",
  "modality": "scRNA-seq",
  "conditions": [
  "healthy",
- "breast_cancer"
+ "Alzheimer"
  ],
- "budget_limit": 89988.7088221059,
- "time_limit_days": 160.54122628462622
+ "budget_limit": 84032.3891133686,
+ "time_limit_days": 116.86792941835967
  },
  "resources": {
  "budget_used": 28900.0,
- "budget_remaining": 61088.71,
+ "budget_remaining": 55132.39,
  "time_used_days": 18.0,
- "time_remaining_days": 142.5,
+ "time_remaining_days": 98.9,
  "samples_consumed": 0,
  "compute_hours_used": 12.0
  },
@@ -32,9 +32,9 @@
  "step_index": 1,
  "action_type": "collect_sample",
  "method": null,
- "output_summary": "Collected 6 samples (quality=0.42)",
+ "output_summary": "Collected 6 samples (quality=0.56)",
  "success": true,
- "quality_score": 0.422,
+ "quality_score": 0.562,
  "resource_cost": 5000.0,
  "time_cost_days": 7.0
  },
@@ -62,9 +62,9 @@
  "step_index": 4,
  "action_type": "prepare_library",
  "method": null,
- "output_summary": "Library prepared (complexity=0.59)",
+ "output_summary": "Library prepared (complexity=0.75)",
  "success": true,
- "quality_score": 0.595,
+ "quality_score": 0.751,
  "resource_cost": 8000.0,
  "time_cost_days": 3.0
  },
@@ -72,42 +72,42 @@
  "step_index": 5,
  "action_type": "sequence_cells",
  "method": null,
- "output_summary": "Sequenced 8578 cells, 9244 genes detected, median UMI=3090",
+ "output_summary": "Sequenced 4574 cells, 9256 genes detected, median UMI=3083",
  "success": true,
- "quality_score": 0.359,
+ "quality_score": 0.47,
  "resource_cost": 15000.0,
  "time_cost_days": 5.0
  },
  {
  "step_index": 6,
+ "action_type": "filter_data",
+ "method": null,
+ "output_summary": "Action blocked: Cannot filter data before QC",
+ "success": false,
+ "quality_score": 1.0,
+ "resource_cost": 50.0,
+ "time_cost_days": 0.2
+ },
+ {
+ "step_index": 7,
  "action_type": "run_qc",
  "method": null,
  "output_summary": "QC metrics computed",
  "success": true,
- "quality_score": 0.821,
+ "quality_score": 0.829,
  "resource_cost": 100.0,
  "time_cost_days": 0.5
  },
  {
- "step_index": 7,
+ "step_index": 8,
  "action_type": "filter_data",
  "method": null,
- "output_summary": "Filtered 8578 \u2192 7050 cells (82% retained)",
+ "output_summary": "Filtered 4574 \u2192 3996 cells (87% retained)",
  "success": true,
- "quality_score": 0.822,
+ "quality_score": 0.874,
  "resource_cost": 50.0,
  "time_cost_days": 0.2
  },
- {
- "step_index": 8,
- "action_type": "cluster_cells",
- "method": null,
- "output_summary": "Action blocked: Cannot cluster before normalisation",
- "success": false,
- "quality_score": 1.0,
- "resource_cost": 100.0,
- "time_cost_days": 0.5
- },
  {
  "step_index": 9,
  "action_type": "normalize_data",
@@ -120,11 +120,11 @@
  },
  {
  "step_index": 10,
- "action_type": "marker_selection",
+ "action_type": "cluster_cells",
  "method": null,
- "output_summary": "Action blocked: Cannot select markers without DE results",
- "success": false,
- "quality_score": 1.0,
+ "output_summary": "Found 7 clusters",
+ "success": true,
+ "quality_score": 0.79,
  "resource_cost": 100.0,
  "time_cost_days": 0.5
  },
@@ -132,20 +132,20 @@
  "step_index": 11,
  "action_type": "design_followup_experiment",
  "method": null,
- "output_summary": "Action blocked: Follow-up design without prior analysis is blocked; complete wet-lab and computational steps first",
- "success": false,
- "quality_score": 1.0,
+ "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
+ "success": true,
+ "quality_score": 0.28,
  "resource_cost": 100.0,
  "time_cost_days": 0.5
  },
  {
  "step_index": 12,
- "action_type": "cluster_cells",
+ "action_type": "synthesize_conclusion",
  "method": null,
- "output_summary": "Found 7 clusters",
- "success": true,
- "quality_score": 0.741,
- "resource_cost": 100.0,
+ "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
+ "success": false,
+ "quality_score": 1.0,
+ "resource_cost": 0.0,
  "time_cost_days": 0.5
  },
  {
@@ -160,26 +160,6 @@
  },
  {
  "step_index": 14,
- "action_type": "design_followup_experiment",
- "method": null,
- "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
- "success": true,
- "quality_score": 0.28,
- "resource_cost": 100.0,
- "time_cost_days": 0.5
- },
- {
- "step_index": 15,
- "action_type": "synthesize_conclusion",
- "method": null,
- "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
- "success": false,
- "quality_score": 1.0,
- "resource_cost": 0.0,
- "time_cost_days": 0.5
- },
- {
- "step_index": 16,
  "action_type": "synthesize_conclusion",
  "method": null,
  "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
@@ -222,8 +202,8 @@
  "Cannot synthesise conclusion before inferring pathways or mechanisms"
  ],
  "uncertainty_summary": {
- "avg_uncertainty": 0.177,
- "avg_quality": 0.804
+ "avg_uncertainty": 0.224,
+ "avg_quality": 0.814
  },
  "reward_breakdown": {
  "validity": -1.0,
@@ -249,92 +229,83 @@
  "latent": {
  "cell_populations": [
  {
- "name": "luminal_epithelial",
- "proportion": 0.433,
+ "name": "excitatory_neuron",
+ "proportion": 0.349,
  "marker_genes": [
- "KRT8",
- "KRT18",
- "EPCAM"
+ "SLC17A7",
+ "CAMK2A",
+ "NRGN"
  ],
- "state": "normal"
+ "state": "stressed"
  },
  {
- "name": "basal_epithelial",
- "proportion": 0.157,
+ "name": "inhibitory_neuron",
+ "proportion": 0.209,
  "marker_genes": [
- "KRT14",
- "KRT5",
- "TP63"
+ "GAD1",
+ "GAD2",
+ "SLC32A1"
  ],
  "state": "normal"
  },
  {
- "name": "fibroblast",
- "proportion": 0.119,
+ "name": "astrocyte",
+ "proportion": 0.211,
  "marker_genes": [
- "COL1A1",
- "COL3A1",
- "FAP"
+ "GFAP",
+ "AQP4",
+ "SLC1A3"
  ],
  "state": "quiescent"
  },
  {
- "name": "T_cell",
- "proportion": 0.105,
- "marker_genes": [
- "CD3D",
- "CD3E",
- "CD8A"
- ],
- "state": "activated"
- },
- {
- "name": "macrophage",
- "proportion": 0.096,
+ "name": "oligodendrocyte",
+ "proportion": 0.153,
  "marker_genes": [
- "CD68",
- "CD163",
- "CSF1R"
+ "MBP",
+ "PLP1",
+ "MOG"
  ],
- "state": "inflammatory"
+ "state": "myelinating"
  },
  {
- "name": "endothelial",
- "proportion": 0.09,
+ "name": "OPC",
+ "proportion": 0.078,
  "marker_genes": [
- "PECAM1",
- "VWF",
- "CDH5"
+ "PDGFRA",
+ "CSPG4",
+ "OLIG2"
  ],
- "state": "quiescent"
+ "state": "progenitor"
  }
  ],
  "true_markers": [
- "ERBB2",
- "MKI67",
- "CD274",
- "VIM"
+ "TREM2",
+ "APOE",
+ "GFAP",
+ "C1QA"
  ],
  "causal_mechanisms": [
- "ERBB2-driven proliferative signalling",
- "immune evasion via PD-L1 upregulation"
+ "TREM2-mediated microglial activation in amyloid clearance",
+ "complement-driven synaptic pruning",
+ "reactive astrogliosis amplifying neuroinflammation"
  ],
  "true_pathways": {
- "cell_cycle": 0.889,
- "PI3K_AKT_signalling": 0.803,
- "EMT": 0.757,
- "immune_checkpoint": 0.579,
- "estrogen_signalling": 0.644
+ "complement_cascade": 0.839,
+ "neuroinflammation": 0.805,
+ "amyloid_processing": 0.666,
+ "synaptic_signalling": 0.394,
+ "lipid_metabolism": 0.674
  },
- "true_de_genes_count": 9,
- "true_regulatory_network_size": 9,
+ "true_de_genes_count": 10,
+ "true_regulatory_network_size": 0,
  "confounders": {},
- "n_true_cells": 14493,
+ "n_true_cells": 7619,
  "technical": {
- "ambient_rna_fraction": 0.05700028722692205,
- "doublet_rate": 0.0716382392677839,
- "dropout_rate": 0.1822398381996976,
- "sample_quality": 0.7058144963381642,
+ "ambient_rna_fraction": 0.04108598341080635,
+ "doublet_rate": 0.045763110874719674,
+ "dropout_rate": 0.07138299827651534,
+ "sample_quality": 0.9242864922806572,
  "library_complexity": 0.8,
  "capture_efficiency": 0.6
  },
@@ -359,8 +330,8 @@
  "followup_designed": true,
  "subagent_review_requested": false,
  "conclusion_reached": false,
- "n_cells_sequenced": 8578,
- "n_cells_after_filter": 7050,
+ "n_cells_sequenced": 4574,
+ "n_cells_after_filter": 3996,
  "n_clusters_found": "7",
  "n_de_genes_found": null,
  "n_markers_found": null
colab_train_llama32_remote.py ADDED
@@ -0,0 +1,340 @@
+ """Minimal Colab entrypoint for Unsloth GRPO against a remote OpenEnv Space.
+
+ This keeps the repo's prompt formatting and action parsing logic, but builds
+ prompt states by interacting with a deployed OpenEnv Hugging Face Space instead
+ of the local in-process environment. That makes the Colab workflow match the
+ remote environment users actually want to train against.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import random
+ from typing import Any, Dict, List, Optional, Sequence
+
+ from client import BioExperimentEnv
+ import training_script as base
+
+ DEFAULT_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
+ DEFAULT_OUTPUT_DIR = "artifacts/grpo-unsloth-llama32-3b-space"
+ DEFAULT_SPACE_REPO_ID = "Ev3Dev/hackathon"
+
+
+ def hf_space_repo_to_base_url(repo_id: str) -> str:
+     """Convert `owner/space-name` to the standard `hf.space` URL."""
+     owner, space_name = repo_id.split("/", 1)
+     normalized_owner = owner.strip().lower().replace("_", "-")
+     normalized_space = space_name.strip().lower().replace("_", "-")
+     return f"https://{normalized_owner}-{normalized_space}.hf.space"
+
+
+ def require_unsloth_base():
+     # Unsloth must be imported before trl / transformers / peft.
+     import unsloth  # noqa: F401
+     import training_unsloth as unsloth_base
+
+     return unsloth_base
+
+
+ def build_argument_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(
+         description="Train Unsloth Llama 3.2 3B on a remote OpenEnv Hugging Face Space."
+     )
+     parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
+     parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
+     parser.add_argument("--dataset-episodes", type=int, default=8)
+     parser.add_argument("--rollout-steps", type=int, default=6)
+     parser.add_argument(
+         "--collection-policy",
+         choices=["random", "heuristic"],
+         default="heuristic",
+     )
+     parser.add_argument("--base-url", default="")
+     parser.add_argument(
+         "--space-repo-id",
+         default=DEFAULT_SPACE_REPO_ID,
+         help="Hugging Face Space repo id, for example `Ev3Dev/hackathon`.",
+     )
+     parser.add_argument("--num-generations", type=int, default=2)
+     parser.add_argument("--max-completion-length", type=int, default=160)
+     parser.add_argument("--max-prompt-length", type=int, default=1280)
+     parser.add_argument("--max-seq-length", type=int, default=2048)
+     parser.add_argument("--per-device-train-batch-size", type=int, default=1)
+     parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
+     parser.add_argument("--learning-rate", type=float, default=5e-6)
+     parser.add_argument("--num-train-epochs", type=float, default=1.0)
+     parser.add_argument("--logging-steps", type=int, default=1)
+     parser.add_argument("--save-steps", type=int, default=25)
+     parser.add_argument("--plot-metric-key", default=None)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--dry-run", action="store_true")
+     parser.add_argument("--load-model-only", action="store_true")
+     parser.add_argument("--trust-remote-code", action="store_true")
+     parser.add_argument("--disable-4bit", action="store_true")
+     parser.add_argument("--lora-r", type=int, default=unsloth_defaults()["lora_r"])
+     parser.add_argument(
+         "--lora-alpha", type=int, default=unsloth_defaults()["lora_alpha"]
+     )
+     parser.add_argument(
+         "--lora-dropout", type=float, default=unsloth_defaults()["lora_dropout"]
+     )
+     return parser
+
+
+ def unsloth_defaults() -> Dict[str, float]:
+     return {
+         "lora_r": 16,
+         "lora_alpha": 16,
+         "lora_dropout": 0.0,
+     }
+
+
+ def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
+     args = build_argument_parser().parse_args(argv)
+     if not args.base_url:
+         args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
+     return args
+
+
+ def make_training_args(**overrides: Any) -> argparse.Namespace:
+     parser = build_argument_parser()
+     defaults = vars(parser.parse_args([]))
+     unknown = sorted(set(overrides) - set(defaults))
+     if unknown:
+         raise ValueError(f"Unknown training args: {', '.join(unknown)}")
+     defaults.update(overrides)
+     args = argparse.Namespace(**defaults)
+     if not getattr(args, "base_url", ""):
+         args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
+     return args
+
+
+ def build_remote_prompt_examples(args: argparse.Namespace) -> List[Dict[str, str]]:
+     """Collect prompt states directly from the remote OpenEnv server."""
+     rng = random.Random(args.seed)
+     examples: List[Dict[str, str]] = []
+
+     for _episode_idx in range(args.dataset_episodes):
+         with BioExperimentEnv(base_url=args.base_url) as env:
+             result = env.reset()
+             obs = result.observation
+             history_actions: List[base.ExperimentAction] = []
+
+             for step_idx in range(args.rollout_steps):
+                 if obs.done:
+                     break
+
+                 next_action = base.build_experiment_action(
+                     action_type=base.pick_action(
+                         args.collection_policy,
+                         step_idx,
+                         [action.action_type for action in history_actions],
+                     ),
+                     discovered_markers=obs.discovered_markers,
+                     candidate_mechanisms=obs.candidate_mechanisms,
+                     conditions=obs.task.conditions,
+                 )
+                 examples.append(
+                     {
+                         "prompt": base.build_training_prompt(obs),
+                         "history_actions": json.dumps(
+                             [action.model_dump() for action in history_actions]
+                         ),
+                         "reference_action": base.action_completion_json(next_action),
+                         "problem_statement": obs.task.problem_statement,
+                         "episode_tag": f"remote-{rng.randrange(10**9):09d}",
+                     }
+                 )
+
+                 history_actions.append(next_action)
+                 result = env.step(next_action)
+                 obs = result.observation
+                 if result.done:
+                     break
+
+     return examples
+
+
+ class RemoteSpaceReward:
+     """Reward function that replays each candidate against the remote Space."""
+
+     def __init__(
+         self,
+         *,
+         base_url: str,
+         invalid_action_penalty: float = base.INVALID_ACTION_PENALTY,
+         environment_error_penalty: float = base.ENVIRONMENT_ERROR_PENALTY,
+     ) -> None:
+         self.__name__ = "remote_space_reward"
+         self.base_url = base_url
+         self.invalid_action_penalty = invalid_action_penalty
+         self.environment_error_penalty = environment_error_penalty
+
+     def __call__(
+         self,
+         completions: List[Any],
+         history_actions: Optional[List[str]] = None,
+         **_: Any,
+     ) -> List[float]:
+         history_columns = base.normalise_column(history_actions, len(completions))
+         rewards: List[float] = []
+
+         for completion, current_history in zip(completions, history_columns):
+             action = base.parse_action_completion(base.completion_to_text(completion))
+             if action is None:
+                 rewards.append(self.invalid_action_penalty)
+                 continue
+
+             try:
+                 rewards.append(self._score_remote(action, current_history))
+             except Exception:
+                 rewards.append(self.environment_error_penalty)
+
+         return rewards
+
+     def _score_remote(
+         self,
+         action: base.ExperimentAction,
+         history_actions: Optional[str],
+     ) -> float:
+         with BioExperimentEnv(base_url=self.base_url) as env:
+             result = env.reset()
+             obs = result.observation
+
+             for previous_action in base.decode_history_actions(history_actions):
+                 result = env.step(previous_action)
+                 obs = result.observation
+                 if result.done:
+                     return float(result.reward or obs.reward or 0.0)
+
+             action = base.ensure_conclusion_claims(obs, action)
+             result = env.step(action)
+             if result.reward is not None:
+                 return float(result.reward)
+             return float(result.observation.reward)
+
+
+ def run_dry_run_preview(
+     examples: Sequence[Dict[str, str]],
+     reward_fn: RemoteSpaceReward,
+     output_dir: str,
+     base_url: str,
+ ) -> None:
+     if not examples:
+         raise ValueError("No training prompts were generated for the dry run.")
+
+     sample = examples[0]
+     sample_reward = reward_fn(
+         completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
+         history_actions=[sample["history_actions"]],
+     )[0]
+
+     print(f"Built {len(examples)} remote prompt states.")
+     print(f"Remote OpenEnv Space: {base_url}")
+     print(f"Output directory: {output_dir}")
+     print(f"Sample reward for reference action: {sample_reward:+.3f}")
+     print("\nSample prompt:\n")
+     print(sample["prompt"])
+
+
+ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
+     random.seed(args.seed)
+     runtime = base.resolve_torch_runtime()
+     unsloth_base = require_unsloth_base()
+
+     if args.load_model_only:
+         tokenizer, model = unsloth_base.load_model_artifacts(
+             args.model_id,
+             trust_remote_code=args.trust_remote_code,
+             max_seq_length=args.max_seq_length,
+             load_in_4bit=not args.disable_4bit,
+             fast_inference=False,
+             prepare_for_inference=True,
+         )
+         return {
+             "args": args,
+             "runtime": runtime,
+             "tokenizer": tokenizer,
+             "model": model,
+         }
+
+     examples = build_remote_prompt_examples(args)
+     reward_fn = RemoteSpaceReward(base_url=args.base_url)
+
+     if args.dry_run:
+         run_dry_run_preview(examples, reward_fn, args.output_dir, args.base_url)
+         return {
+             "args": args,
+             "runtime": runtime,
+             "examples": examples,
+             "reward_fn": reward_fn,
+         }
+
+     from datasets import Dataset
+
+     FastLanguageModel = unsloth_base.patch_unsloth_grpo()
+     train_dataset = Dataset.from_list(examples)
+
+     tokenizer, model = unsloth_base.load_model_artifacts(
+         args.model_id,
+         trust_remote_code=args.trust_remote_code,
+         max_seq_length=args.max_seq_length,
+         load_in_4bit=not args.disable_4bit,
+         fast_inference=False,
+     )
+     model = unsloth_base.apply_lora_adapters(FastLanguageModel, model, args)
+
+     print(
+         f"Training runtime: device={runtime['device']} "
+         f"name={runtime['device_name']} "
+         f"dtype={runtime['dtype']} "
+         f"load_in_4bit={not args.disable_4bit}"
+     )
+     print(f"Remote OpenEnv Space: {args.base_url}")
+     print(f"Collected remote prompt states: {len(examples)}")
+
+     trainer = unsloth_base.build_unsloth_grpo_trainer(
+         model=model,
+         tokenizer=tokenizer,
+         reward_func=reward_fn,
+         train_dataset=train_dataset,
+         args=args,
+         runtime=runtime,
+     )
+     for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
+         if not hasattr(trainer, attr):
+             setattr(trainer, attr, None)
+
+     trainer.train()
+     trainer.save_model(args.output_dir)
+     tokenizer.save_pretrained(args.output_dir)
+
+     plot_paths = base.save_training_plots(
+         trainer.state.log_history,
+         args.output_dir,
+         metric_key=args.plot_metric_key,
+     )
+     print("Saved training plots:")
+     for plot_name, plot_path in plot_paths.items():
+         print(f" - {plot_name}: {plot_path}")
+
+     return {
+         "args": args,
+         "runtime": runtime,
+         "examples": examples,
+         "reward_fn": reward_fn,
+         "train_dataset": train_dataset,
+         "tokenizer": tokenizer,
+         "model": model,
+         "trainer": trainer,
+         "plot_paths": plot_paths,
+     }
+
+
+ def main() -> None:
+     run_training(parse_args())
+
+
+ if __name__ == "__main__":
+     main()
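For a quick smoke test of this entrypoint without launching a full GRPO run, the module's own make_training_args/run_training API can be driven with dry_run=True, which only collects remote prompt states and scores the reference action. A sketch, assuming the repo root is on sys.path and the default Space is reachable:

# Minimal dry-run sketch; episode/step counts are kept tiny on purpose.
from colab_train_llama32_remote import make_training_args, run_training

args = make_training_args(dry_run=True, dataset_episodes=1, rollout_steps=2)
result = run_training(args)  # prints a prompt preview and a sample reward
print(len(result["examples"]), "prompt states collected")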
colab_train_unsloth.ipynb CHANGED
@@ -1,128 +1,347 @@
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Unsloth training on Colab\n",
- "\n",
- "Minimal setup: clone repo, install deps, run GRPO training with Unsloth (Qwen3-4B, 4-bit + LoRA).\n",
- "\n",
- "**Runtime**: Enable a GPU (e.g. T4) in Colab: Runtime → Change runtime type → GPU."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 1. Clone repo (set branch/tag if needed)\n",
- "REPO_URL = \"https://github.com/mhtruong1031/OpenENV-Hackathon.git\" # or your fork\n",
- "REPO_DIR = \"OpenENV-Hackathon\"\n",
- "\n",
- "!git clone --depth 1 {REPO_URL} {REPO_DIR}\n",
- "%cd {REPO_DIR}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 2. Install requirements: project + train extras + Unsloth (no-deps to keep trl>=0.29)\n",
- "!pip install -q -e \".[train]\"\n",
- "!pip install -q unsloth unsloth_zoo --no-deps\n",
- "\n",
- "# Optional: reward backends\n",
- "!pip install -q sentence-transformers gseapy 2>/dev/null || true"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 3. Unsloth must be imported before trl/transformers/peft\n",
- "import unsloth # noqa: F401\n",
- "import torch\n",
- "from pathlib import Path\n",
- "\n",
- "from training_unsloth import make_training_args, run_training\n",
- "\n",
- "print(\"CUDA:\", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"\")\n",
- "Path(\"artifacts\").mkdir(exist_ok=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 4. Training config (small run for Colab T4)\n",
- "args = make_training_args(\n",
- " model_id=\"Qwen/Qwen3-4B-Base\",\n",
- " output_dir=\"artifacts/grpo-unsloth-qwen3-4b\",\n",
- " dataset_episodes=16,\n",
- " rollout_steps=10,\n",
- " collection_policy=\"heuristic\",\n",
- " reward_backend=\"local\",\n",
- " domain_randomise=True,\n",
- " num_generations=4,\n",
- " max_completion_length=160,\n",
- " max_prompt_length=1280,\n",
- " max_seq_length=2048,\n",
- " per_device_train_batch_size=2,\n",
- " gradient_accumulation_steps=4,\n",
- " learning_rate=5e-6,\n",
- " num_train_epochs=1.0,\n",
- " logging_steps=1,\n",
- " save_steps=25,\n",
- " trust_remote_code=True,\n",
- " dry_run=False,\n",
- " seed=42,\n",
- ")\n",
- "args"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 5. Run training\n",
- "result = run_training(args)\n",
- "print(\"Plots:\", result[\"plot_paths\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 6. (Optional) Show loss curves\n",
- "from IPython.display import Image, display\n",
- "for name, path in result[\"plot_paths\"].items():\n",
- " display(Image(filename=path))"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.10.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
- }
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Minimal Unsloth GRPO on Colab with a remote OpenEnv Space\n",
+ "\n",
+ "This notebook is intentionally similar to the 2048 notebook pattern:\n",
+ "- training runs locally inside Colab\n",
+ "- the environment is accessed remotely through a Hugging Face Space\n",
+ "- the reward function is defined in notebook code by replaying actions against that remote env\n",
+ "- prompt / action / conclusion formatting mirrors the repo logic without importing the repo training script\n",
+ "\n",
+ "Default remote env: `Ev3Dev/hackathon`\n",
+ "\n",
+ "**Runtime**: Enable a GPU in Colab: Runtime -> Change runtime type -> GPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 1. Clone the repo for lightweight client / model definitions only\n",
+ "REPO_URL = \"https://github.com/mhtruong1031/OpenENV-Hackathon.git\" # or your fork\n",
+ "REPO_DIR = \"OpenENV-Hackathon\"\n",
+ "\n",
+ "!git clone --depth 1 {REPO_URL} {REPO_DIR}\n",
+ "%cd {REPO_DIR}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 2. Install only the runtime pieces needed for notebook-side training\n",
+ "!pip install -q unsloth unsloth_zoo --no-deps\n",
+ "!pip install -q \"openenv-core[core]>=0.2.0\" \"pydantic>=2\" \"numpy>=1.24.0\" \"scipy>=1.10.0\" \"datasets>=4.6.1\" \"accelerate>=1.13.0\" \"peft>=0.15.0\" \"bitsandbytes>=0.45.0\" \"matplotlib>=3.8.0\"\n",
+ "!pip install -q \"transformers>=4.57.0\" \"trl>=0.29.0\" \"torchvision>=0.20.0\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3. Import repo reward helpers, but keep the environment remote\n",
+ "import inspect\n",
+ "import json\n",
+ "import random\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "from typing import Any, Dict, List\n",
+ "\n",
+ "# Unsloth must be imported before trl / transformers / peft.\n",
+ "import unsloth # noqa: F401\n",
+ "import torch\n",
+ "from unsloth import FastLanguageModel, PatchFastRL\n",
+ "\n",
+ "sys.path.insert(0, str(Path.cwd()))\n",
+ "\n",
+ "from client import BioExperimentEnv\n",
+ "from models import ActionType, ExperimentAction\n",
+ "from training_script import (\n",
+ " INVALID_ACTION_PENALTY,\n",
+ " ENVIRONMENT_ERROR_PENALTY,\n",
+ " OpenEnvReward,\n",
+ " build_training_prompt,\n",
+ " build_experiment_action,\n",
+ " decode_history_actions,\n",
+ " pick_action,\n",
+ " save_training_plots,\n",
+ ")\n",
+ "\n",
+ "MAX_COMPLETION_TOKENS = 160\n",
+ "LORA_TARGET_MODULES = [\n",
+ " \"q_proj\",\n",
+ " \"k_proj\",\n",
+ " \"v_proj\",\n",
+ " \"o_proj\",\n",
+ " \"gate_proj\",\n",
+ " \"up_proj\",\n",
+ " \"down_proj\",\n",
+ "]\n",
+ "\n",
+ "\n",
+ "def hf_space_repo_to_base_url(repo_id: str) -> str:\n",
+ " owner, space_name = repo_id.split(\"/\", 1)\n",
+ " return f\"https://{owner.lower().replace('_', '-')}-{space_name.lower().replace('_', '-')}.hf.space\"\n",
+ "\n",
+ "\n",
+ "def build_remote_prompt_examples(\n",
+ " base_url: str,\n",
+ " dataset_episodes: int,\n",
+ " rollout_steps: int,\n",
+ " seed: int,\n",
+ ") -> List[Dict[str, str]]:\n",
+ " rng = random.Random(seed)\n",
+ " examples: List[Dict[str, str]] = []\n",
+ "\n",
+ " for _ in range(dataset_episodes):\n",
+ " with BioExperimentEnv(base_url=base_url) as env:\n",
+ " result = env.reset()\n",
+ " obs = result.observation\n",
+ " history_actions: List[ExperimentAction] = []\n",
+ "\n",
+ " for step_idx in range(rollout_steps):\n",
+ " if obs.done:\n",
+ " break\n",
+ "\n",
+ " next_action = build_experiment_action(\n",
+ " action_type=pick_action(\n",
+ " \"heuristic\",\n",
+ " step_idx,\n",
+ " [action.action_type for action in history_actions],\n",
+ " ),\n",
+ " discovered_markers=obs.discovered_markers,\n",
+ " candidate_mechanisms=obs.candidate_mechanisms,\n",
+ " conditions=obs.task.conditions,\n",
+ " )\n",
+ " examples.append(\n",
+ " {\n",
+ " \"prompt\": build_training_prompt(obs),\n",
+ " \"history_actions\": json.dumps(\n",
+ " [action.model_dump() for action in history_actions]\n",
+ " ),\n",
+ " \"reference_action\": json.dumps(next_action.model_dump()),\n",
+ " \"problem_statement\": obs.task.problem_statement,\n",
+ " \"episode_tag\": f\"remote-{rng.randrange(10**9):09d}\",\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " history_actions.append(next_action)\n",
+ " result = env.step(next_action)\n",
+ " obs = result.observation\n",
+ " if result.done:\n",
+ " break\n",
+ "\n",
+ " return examples\n",
+ "\n",
+ "\n",
+ "def build_grpo_config(**overrides: Any):\n",
+ " from trl import GRPOConfig\n",
+ "\n",
+ " supported = set(inspect.signature(GRPOConfig.__init__).parameters)\n",
+ " config_kwargs = {\n",
+ " \"output_dir\": overrides[\"output_dir\"],\n",
+ " \"learning_rate\": overrides[\"learning_rate\"],\n",
+ " \"per_device_train_batch_size\": overrides[\"per_device_train_batch_size\"],\n",
+ " \"gradient_accumulation_steps\": overrides[\"gradient_accumulation_steps\"],\n",
+ " \"num_generations\": overrides[\"num_generations\"],\n",
+ " \"max_completion_length\": overrides[\"max_completion_length\"],\n",
+ " \"num_train_epochs\": overrides[\"num_train_epochs\"],\n",
+ " \"logging_steps\": overrides[\"logging_steps\"],\n",
+ " \"save_steps\": overrides[\"save_steps\"],\n",
+ " \"bf16\": overrides[\"bf16\"],\n",
+ " \"fp16\": overrides[\"fp16\"],\n",
+ " \"report_to\": \"none\",\n",
+ " \"remove_unused_columns\": False,\n",
+ " }\n",
+ " # Keep prompt truncation enabled. Leaving this as None can trigger\n",
+ " # an Unsloth rotary-cache shape mismatch on long GRPO prompts.\n",
+ " if \"max_prompt_length\" in supported:\n",
+ " config_kwargs[\"max_prompt_length\"] = overrides[\"max_prompt_length\"]\n",
+ " if (\n",
+ " \"max_length\" in supported\n",
+ " and \"max_prompt_length\" not in supported\n",
+ " and \"max_completion_length\" not in supported\n",
+ " ):\n",
+ " config_kwargs[\"max_length\"] = (\n",
+ " overrides[\"max_prompt_length\"] + overrides[\"max_completion_length\"]\n",
+ " )\n",
+ " return GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported})\n",
+ "\n",
+ "\n",
+ "print(\"CUDA:\", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"\")\n",
+ "Path(\"artifacts\").mkdir(exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 4. Config + collect prompt states from the remote Space\n",
+ "SPACE_REPO_ID = \"Ev3Dev/hackathon\"\n",
+ "SPACE_BASE_URL = hf_space_repo_to_base_url(SPACE_REPO_ID)\n",
+ "# If your Space has a custom domain, replace SPACE_BASE_URL manually.\n",
+ "\n",
+ "MODEL_ID = \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\"\n",
+ "OUTPUT_DIR = \"artifacts/grpo-unsloth-llama32-3b-remote-space\"\n",
+ "\n",
+ "DATASET_EPISODES = 8\n",
+ "ROLLOUT_STEPS = 6\n",
+ "NUM_GENERATIONS = 2\n",
+ "# Keep this modest for Unsloth GRPO stability on Colab.\n",
+ "MAX_PROMPT_LENGTH = 768\n",
+ "MAX_SEQ_LENGTH = 2048\n",
+ "PER_DEVICE_TRAIN_BATCH_SIZE = 1\n",
+ "GRADIENT_ACCUMULATION_STEPS = 4\n",
+ "LEARNING_RATE = 5e-6\n",
+ "NUM_TRAIN_EPOCHS = 1.0\n",
+ "LOGGING_STEPS = 1\n",
+ "SAVE_STEPS = 25\n",
+ "SEED = 42\n",
+ "LORA_R = 16\n",
+ "LORA_ALPHA = 16\n",
+ "LORA_DROPOUT = 0.0\n",
+ "\n",
+ "examples = build_remote_prompt_examples(\n",
+ " base_url=SPACE_BASE_URL,\n",
+ " dataset_episodes=DATASET_EPISODES,\n",
+ " rollout_steps=ROLLOUT_STEPS,\n",
+ " seed=SEED,\n",
+ ")\n",
+ "\n",
+ "reward_fn = OpenEnvReward(\n",
+ " reward_backend=\"remote\",\n",
+ " base_url=SPACE_BASE_URL,\n",
+ " invalid_action_penalty=INVALID_ACTION_PENALTY,\n",
+ " environment_error_penalty=ENVIRONMENT_ERROR_PENALTY,\n",
+ ")\n",
+ "\n",
+ "print(\"Remote env:\", SPACE_BASE_URL)\n",
+ "print(\"Prompt states:\", len(examples))\n",
+ "print(\"Sample prompt preview:\\n\")\n",
+ "print(examples[0][\"prompt\"][:2000])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 5. Local GRPO training in Colab, remote env for rewards\n",
+ "from datasets import Dataset\n",
+ "from trl import GRPOTrainer\n",
+ "\n",
+ "PatchFastRL(\"GRPO\", FastLanguageModel)\n",
+ "train_dataset = Dataset.from_list(examples)\n",
+ "\n",
+ "bf16 = bool(getattr(torch.cuda, \"is_bf16_supported\", lambda: False)()) if torch.cuda.is_available() else False\n",
+ "runtime_dtype = torch.bfloat16 if bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=MODEL_ID,\n",
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
+ " dtype=runtime_dtype,\n",
+ " load_in_4bit=True,\n",
+ ")\n",
+ "if tokenizer.pad_token is None and tokenizer.eos_token is not None:\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r=LORA_R,\n",
+ " target_modules=LORA_TARGET_MODULES,\n",
+ " lora_alpha=LORA_ALPHA,\n",
+ " lora_dropout=LORA_DROPOUT,\n",
+ " bias=\"none\",\n",
+ " use_gradient_checkpointing=True,\n",
+ " random_state=SEED,\n",
+ ")\n",
+ "\n",
+ "training_args = build_grpo_config(\n",
+ " output_dir=OUTPUT_DIR,\n",
+ " learning_rate=LEARNING_RATE,\n",
+ " per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,\n",
+ " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
+ " num_generations=NUM_GENERATIONS,\n",
+ " max_completion_length=MAX_COMPLETION_TOKENS,\n",
+ " max_prompt_length=MAX_PROMPT_LENGTH,\n",
+ " num_train_epochs=NUM_TRAIN_EPOCHS,\n",
+ " logging_steps=LOGGING_STEPS,\n",
+ " save_steps=SAVE_STEPS,\n",
+ " bf16=bf16,\n",
+ " fp16=torch.cuda.is_available() and not bf16,\n",
+ ")\n",
+ "\n",
+ "trainer = GRPOTrainer(\n",
+ " model=model,\n",
+ " reward_funcs=[reward_fn],\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " processing_class=tokenizer,\n",
+ ")\n",
+ "\n",
+ "for attr in (\"image_token_id\", \"vision_start_token_id\", \"vision_end_token_id\"):\n",
+ " if not hasattr(trainer, attr):\n",
+ " setattr(trainer, attr, None)\n",
+ "\n",
+ "trainer.train()\n",
+ "trainer.save_model(OUTPUT_DIR)\n",
+ "tokenizer.save_pretrained(OUTPUT_DIR)\n",
+ "plot_paths = save_training_plots(trainer.state.log_history, OUTPUT_DIR)\n",
+ "\n",
+ "result = {\n",
+ " \"trainer\": trainer,\n",
+ " \"plot_paths\": plot_paths,\n",
+ " \"output_dir\": OUTPUT_DIR,\n",
+ "}\n",
+ "print(\"Saved to:\", OUTPUT_DIR)\n",
+ "print(\"Plots:\", plot_paths)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 6. (Optional) Show curves and sanity-check the repo reward wrapper\n",
+ "from IPython.display import Image, display\n",
+ "\n",
+ "sample_reward = reward_fn(\n",
+ " completions=[[{\"role\": \"assistant\", \"content\": examples[0][\"reference_action\"]}]],\n",
+ " history_actions=[examples[0][\"history_actions\"]],\n",
+ ")[0]\n",
+ "print(\"Sample reward for reference action:\", sample_reward)\n",
+ "\n",
+ "for name, path in (result.get(\"plot_paths\") or {}).items():\n",
+ " print(name, path)\n",
+ " display(Image(filename=path))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
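The replay-based reward the notebook describes reduces to: reset the Space-hosted environment, re-apply the episode history, then step the candidate action and read its reward. A minimal standalone sketch of that loop, assuming the BioExperimentEnv and ExperimentAction interfaces shown in the diffs above:

# Sketch of the notebook's remote-reward pattern (assumed interfaces).
from typing import List

from client import BioExperimentEnv
from models import ExperimentAction


def score_candidate(base_url: str, history: List[ExperimentAction],
                    candidate: ExperimentAction) -> float:
    with BioExperimentEnv(base_url=base_url) as env:
        result = env.reset()
        for previous in history:  # replay the episode so far
            result = env.step(previous)
            if result.done:
                return float(result.reward or 0.0)
        result = env.step(candidate)  # score only the new action
        if result.reward is not None:
            return float(result.reward)
        return float(result.observation.reward)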
pyproject.toml CHANGED
@@ -40,8 +40,6 @@ train = [
  "torch>=2.10.0",
  "torchvision>=0.20.0", # required by transformers for Qwen3.5 (image_utils)
  "transformers>=5.3.0", # 5.3+ required for Qwen3.5 (qwen3_5 model type)
- "llm-blender>=0.0.2", # required by trl GRPOTrainer judges
- "mergekit>=0.1.0", # required by trl GRPOTrainer/callbacks
  "trl>=0.29.0", # GRPOTrainer; 0.29+ compatible with transformers 5.3
  ]
tests/test_colab_train_llama32_remote.py ADDED
@@ -0,0 +1,17 @@
+ from colab_train_llama32_remote import (
+     hf_space_repo_to_base_url,
+     make_training_args,
+ )
+
+
+ def test_hf_space_repo_to_base_url_formats_standard_hf_space_domain():
+     assert (
+         hf_space_repo_to_base_url("Ev3Dev/hackathon")
+         == "https://ev3dev-hackathon.hf.space"
+     )
+
+
+ def test_make_training_args_derives_base_url_when_missing():
+     args = make_training_args(space_repo_id="Ev3Dev/hackathon", base_url="")
+     assert args.base_url == "https://ev3dev-hackathon.hf.space"
+     assert args.model_id == "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
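The new tests are ordinary pytest functions. One way to run just this module, assuming pytest is installed and the repo root is the working directory:

# In-process equivalent of `pytest -q tests/test_colab_train_llama32_remote.py`.
import pytest

exit_code = pytest.main(["-q", "tests/test_colab_train_llama32_remote.py"])
print("pytest exit code:", exit_code)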
unsloth_2048.ipynb ADDED
The diff for this file is too large to render.
 
unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py ADDED
File without changes
unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py ADDED
File without changes
unsloth_compiled_cache/UnslothCPOTrainer.py CHANGED
@@ -28,7 +28,7 @@ import torch.nn as nn
  from torch.nn import functional as F
  from unsloth_zoo.temporary_patches.common import torch_compile
  from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)


  import os
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render.
 
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
The diff for this file is too large to render.
 
unsloth_compiled_cache/UnslothKTOTrainer.py CHANGED
@@ -28,7 +28,7 @@ import torch.nn as nn
  from torch.nn import functional as F
  from unsloth_zoo.temporary_patches.common import torch_compile
  from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch)
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch)


  import os
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,1340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2026.3.2
3
+ 2026.3.4
4
+ 5.3.0
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from unsloth_zoo.temporary_patches.common import torch_compile
30
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
31
+ from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
32
+
33
+
34
+ import os
35
+ import math
36
+ import logging
37
+ from typing import *
38
+ from dataclasses import dataclass, field
39
+ from packaging.version import Version
40
+ import torch
41
+ import numpy as np
42
+ from contextlib import nullcontext
43
+ from torch.nn import functional as F
44
+ import inspect
45
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
46
+ from transformers.training_args import ParallelMode
47
+ from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
48
+
49
+ # Wrap the trainer so padding is on the right and the model is in training mode
50
+ # Also patches W&B, since consecutive runs must call wandb.finish()
51
+ import functools
52
+ from types import MethodType
53
+ try:
54
+ from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
55
+ except:
56
+ def reset_unsloth_gradient_checkpointing_buffers(): pass
57
+ def prepare_for_training_mode(f):
58
+ @functools.wraps(f)
59
+ def wrapper(self, *args, **kwargs):
60
+ # Enable training mode
61
+ _was_training = None
62
+ # Get gradient checkpointing setting from training arguments
63
+ use_gc = getattr(self.args, 'gradient_checkpointing', True)
64
+ if hasattr(self, 'model') and hasattr(self.model, "training"):
65
+ _was_training = self.model.training
66
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
67
+ self.model.for_training(use_gradient_checkpointing=use_gc)
68
+ output = f(self, *args, **kwargs)
69
+ # Restore previous mode when possible
70
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
71
+ if _was_training is False:
72
+ self.model.for_inference()
73
+ elif _was_training is True and hasattr(self.model, "for_training"):
74
+ self.model.for_training(use_gradient_checkpointing=use_gc)
75
+ # Reset gradient checkpointing buffers to free memory while staying ready for next run
76
+ try:
77
+ reset_unsloth_gradient_checkpointing_buffers()
78
+ except:
79
+ pass
80
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
81
+ try:
82
+ import wandb
83
+ wandb.finish()
84
+ except:
85
+ pass
86
+ return output
87
+ return wrapper
88
+ pass
89
+
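# Note: this decorator is applied further down via MethodType, i.e.
#     self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
# so each trainer.train() call enables training mode, restores the previous mode
# afterwards, frees gradient-checkpointing buffers, and closes the W&B run.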
90
+ torch_compile_options = {
91
+ "epilogue_fusion" : True,
92
+ "max_autotune" : False,
93
+ "shape_padding" : True,
94
+ "trace.enabled" : False,
95
+ "triton.cudagraphs" : False,
96
+ }
97
+
98
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
99
+ def chunked_hidden_states_selective_log_softmax(
100
+ hidden_states: torch.Tensor,
101
+ lm_head: torch.Tensor,
102
+ index: torch.Tensor,
103
+ chunks: int = 4,
104
+ logit_scale_multiply: float = 0.0,
105
+ logit_scale_divide: float = 0.0,
106
+ logit_softcapping: float = 0.0,
107
+ temperature: float = 1.0,
108
+ ) -> torch.Tensor:
109
+ # All Unsloth Zoo code licensed under AGPL3
110
+ flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
111
+ flat_index = index.reshape(-1)
112
+
113
+ chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
114
+ chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
115
+
116
+ all_per_token_logps = []
117
+
118
+ for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
119
+ chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
120
+
121
+ if logit_scale_multiply != 0.0:
122
+ chunk_logits = chunk_logits * logit_scale_multiply
123
+ if logit_scale_divide != 0.0:
124
+ chunk_logits = chunk_logits / logit_scale_divide
125
+ if logit_softcapping != 0.0:
126
+ chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
127
+
128
+ chunk_logits = chunk_logits.to(torch.float32)
129
+
130
+ if temperature != 1.0:
131
+ chunk_logits = chunk_logits / temperature
132
+
133
+ selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
134
+ logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
135
+ per_token_logps = selected_logits - logsumexp_values
136
+ all_per_token_logps.append(per_token_logps)
137
+
138
+ all_per_token_logps = torch.concat(all_per_token_logps)
139
+
140
+ all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
141
+ return all_per_token_logps
142
+
143
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
144
+ def chunked_selective_log_softmax(logits, index):
145
+ # Split into 4 chunks only
146
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
147
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
148
+ all_per_token_logps = []
149
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
150
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
151
+ chunk_logits = chunk_logits.to(torch.float32)
152
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
153
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
154
+ per_token_logps = selected_logits - logsumexp_values
155
+ all_per_token_logps.append(per_token_logps)
156
+ pass
157
+ all_per_token_logps = torch.concat(all_per_token_logps)
158
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
159
+ return all_per_token_logps
160
+
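# A minimal numerical check of the chunked helpers above (illustration only;
# shapes are made up). Chunking over the flattened token dimension must match
# plain log_softmax + gather, since each row's softmax is independent; the
# point of chunking is to avoid materializing the full float32 logits at once.
#
#     import torch
#     logits = torch.randn(2, 8, 16)                 # [batch, seq, vocab]
#     index  = torch.randint(0, 16, (2, 8))          # target token ids
#     ref = torch.log_softmax(logits.float(), dim=-1)
#     ref = ref.gather(-1, index.unsqueeze(-1)).squeeze(-1)
#     assert torch.allclose(ref, chunked_selective_log_softmax(logits, index), atol=1e-5)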
161
+ def calculate_pad_tokens_in_prompt(
162
+ input_ids: torch.Tensor,
163
+ logits_to_keep: int,
164
+ pad_token_id: int
165
+ ) -> torch.Tensor:
166
+ """
167
+ Given a prompt tensor, returns the number of left-padding tokens in each sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
168
+ """
169
+ if logits_to_keep >= input_ids.shape[1]:
170
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
171
+
172
+ prompt_section = input_ids[:, :-logits_to_keep]
173
+
174
+ padding_mask = (prompt_section == pad_token_id)
175
+
176
+ pad_token_counts = padding_mask.sum(dim=1)
177
+
178
+ return pad_token_counts
179
+
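# Example (hypothetical ids, pad_token_id = 0): with logits_to_keep = 1 the
# prompt section is everything but the last column, so the left pads are counted.
#
#     import torch
#     ids = torch.tensor([[0, 0, 0, 7, 9],
#                         [0, 7, 8, 9, 9]])
#     calculate_pad_tokens_in_prompt(ids, logits_to_keep=1, pad_token_id=0)
#     # -> tensor([3, 1])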
180
+ def create_completion_attention_mask(
181
+ completion_input_ids: torch.Tensor,
182
+ left_pad_tokens_per_prompt: torch.Tensor,
183
+ max_left_pad: int,
184
+ pad_token_id: int
185
+ ) -> torch.Tensor:
186
+ """
187
+ Given a sequence [p, p, p, c, c, c, pad, pad, pad],
188
+
189
+ where p are extra prompt tokens picked up when slicing the tensor, c are completion
190
+ tokens, and pad are padding tokens, this function builds a completion mask that zeros
191
+ out the p and pad tokens: in this example, [0, 0, 0, 1, 1, 1, 0, 0, 0].
192
+ """
193
+ batch_size, completion_len = completion_input_ids.shape
194
+ device = completion_input_ids.device
195
+
196
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
197
+
198
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
199
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
200
+
201
+ non_padding_mask = (completion_input_ids != pad_token_id)
202
+
203
+ final_mask = shift_mask & non_padding_mask
204
+
205
+ return final_mask
206
+
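# Example (hypothetical tensors, pad_token_id = 0): this row picked up two extra
# prompt tokens (max_left_pad = 2, no left pads of its own), then has three
# completion tokens and two pads.
#
#     import torch
#     comp = torch.tensor([[5, 6, 7, 8, 9, 0, 0]])
#     create_completion_attention_mask(comp, torch.tensor([0]), max_left_pad=2, pad_token_id=0)
#     # -> tensor([[False, False,  True,  True,  True, False, False]])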
207
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
208
+ """
209
+ Moves all padding tokens in each sequence of a batch to the right.
210
+ """
211
+ mask = (tensor != pad_id)
212
+ # Must use stable=True since a binary mask gives no ordering among equal values
213
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
214
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
215
+ return packed_tensor
216
+
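# Example (hypothetical tensor, pad_id = 0): non-pad tokens keep their relative
# order (hence stable=True) and all pads end up on the right.
#
#     import torch
#     x = torch.tensor([[0, 0, 4, 5],
#                       [0, 7, 0, 8]])
#     left_pack_padding(x, pad_id=0)
#     # -> tensor([[4, 5, 0, 0],
#     #            [7, 8, 0, 0]])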
217
+ def align_logprobs_with_mask(
218
+ logprob_tensor: torch.Tensor,
219
+ attention_mask: torch.Tensor,
220
+ pad_value: float = 0.0
221
+ ) -> torch.Tensor:
222
+ """
223
+ Aligns a log probability tensor with a given attention mask.
224
+ """
225
+
226
+ device = logprob_tensor.device
227
+ batch_size, logprob_seq_len = logprob_tensor.shape
228
+ mask_seq_len = attention_mask.shape[1]
229
+
230
+ padded_logprobs = torch.full(
231
+ attention_mask.shape,
232
+ fill_value=pad_value,
233
+ dtype=logprob_tensor.dtype,
234
+ device=device
235
+ )
236
+
237
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
238
+
239
+ cols = torch.arange(logprob_seq_len, device=device)
240
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
241
+
242
+ # Create destination row indices
243
+ # Shape: [batch_size, logprob_seq_len]
244
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
245
+
246
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
247
+ # Create a mask to identify only the indices that are within the bounds
248
+ # of the target tensor's sequence length.
249
+ valid_mask = dest_indices < mask_seq_len
250
+
251
+ # Use this mask to select only the valid row indices, column indices,
252
+ # and the corresponding values from the logprob tensor.
253
+ # This flattens the selected elements into 1D tensors.
254
+ valid_rows = row_indices[valid_mask]
255
+ valid_cols = dest_indices[valid_mask]
256
+ valid_vals = logprob_tensor[valid_mask]
257
+
258
+ # Place the valid values into their correct positions in the padded tensor
259
+ # using a single, efficient advanced indexing operation.
260
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
261
+
262
+ return padded_logprobs
263
+
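# Example (hypothetical tensors): the mask starts with two left pads, so both
# logprobs shift right by two positions and pad slots keep pad_value = 0.0.
#
#     import torch
#     lp   = torch.tensor([[-1.0, -2.0]])
#     mask = torch.tensor([[0, 0, 1, 1]])
#     align_logprobs_with_mask(lp, mask)
#     # -> tensor([[ 0.,  0., -1., -2.]])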
264
+ def autotune_batch_and_chunks(
265
+ total_input_rows,
266
+ seq_len,
267
+ hidden_size,
268
+ vocab_size,
269
+ dtype_bytes=16,
270
+ multiplier=None
271
+ ):
272
+ if multiplier is None:
273
+ final_m = max(4, seq_len // 4096)
274
+ else:
275
+ final_m = multiplier
276
+
277
+ if torch.cuda.is_available():
278
+ free_bytes, _ = torch.cuda.mem_get_info()
279
+ limit_gb = (free_bytes / (1024**3))*.80
280
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
281
+ # For XPU: estimate free memory from total - reserved
282
+ total_mem = torch.xpu.get_device_properties(0).total_memory
283
+ reserved_mem = torch.xpu.memory_reserved()
284
+ free_bytes = total_mem - reserved_mem
285
+ limit_gb = (free_bytes / (1024**3)) * 0.80
286
+ else:
287
+ # Fallback: assume 8GB available
288
+ limit_gb = 8.0
289
+
290
+ bytes_to_gb = 1024**3
291
+
292
+ b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
293
+
294
+ hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
295
+
296
+ base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
297
+ logits_gb = base_logits / final_m
298
+
299
+ total_mem_gb = hidden_gb + logits_gb
300
+
301
+ valid_mask = total_mem_gb <= limit_gb
302
+ valid_indices = torch.nonzero(valid_mask, as_tuple=False)
303
+
304
+ if valid_indices.shape[0] == 0:
305
+ # No batch size fits in the memory limit, so the GPU will likely OOM; fall back to a small batch
306
+ return 4, final_m
307
+
308
+ best_idx = valid_indices[0].item()
309
+ final_b = int(b_vals[best_idx].item())
310
+
311
+ return final_b, final_m
312
+
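# Sketch of the heuristic above (numbers are hypothetical): it scans batch sizes
# b from total_input_rows down to 1 and keeps the largest b whose estimated
# footprint (hidden states plus chunk-scaled logits) fits in roughly 80% of the
# free accelerator memory; for seq_len = 4096 the multiplier m defaults to 4.
#
#     b, m = autotune_batch_and_chunks(total_input_rows=64, seq_len=4096,
#                                      hidden_size=4096, vocab_size=128_000)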
313
+ def sanitize_logprob(logprob):
314
+ """Local port of trl.scripts.vllm_serve.sanitize_logprob.
315
+ Filters NaN logprobs from vLLM outputs."""
316
+ value = logprob.logprob
317
+ if math.isnan(value):
318
+ logging.getLogger(__name__).warning(
319
+ f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
320
+ )
321
+ return None
322
+ return value
323
+ @dataclass
324
+ class UnslothNashMDConfig(NashMDConfig):
325
+ """
326
+
327
+ Configuration class for the [`NashMDTrainer`].
328
+
329
+ Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
330
+
331
+ Parameters:
332
+ mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
333
+ Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
334
+ mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
335
+ epochs.
336
+
337
+ """
338
+ vllm_sampling_params: Optional[Any] = field(
339
+ default = None,
340
+ metadata = {'help': 'vLLM SamplingParams'},
341
+ )
342
+ unsloth_num_chunks : Optional[int] = field(
343
+ default = -1,
344
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
345
+ )
346
+ unsloth_logit_chunk_multiplier : Optional[int] = field(
347
+ default = None,
348
+ metadata = {'help': 'Multiplier for chunked logit computations.'},
349
+ )
350
+ unsloth_grpo_mini_batch : Optional[int] = field(
351
+ default = None,
352
+ metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
353
+ )
354
+ max_seq_length : Optional[int] = field(
355
+ default = None,
356
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
357
+ )
358
+ def __init__(
359
+ self,
360
+ output_dir = None,
361
+ per_device_train_batch_size = 4,
362
+ num_train_epochs = 3.0,
363
+ max_steps = -1,
364
+ learning_rate = 5e-05,
365
+ lr_scheduler_type = 'linear',
366
+ lr_scheduler_kwargs = None,
367
+ warmup_steps = 0.1,
368
+ optim = 'adamw_8bit',
369
+ optim_args = None,
370
+ weight_decay = 0.01,
371
+ adam_beta1 = 0.9,
372
+ adam_beta2 = 0.999,
373
+ adam_epsilon = 1e-08,
374
+ optim_target_modules = None,
375
+ gradient_accumulation_steps = 2,
376
+ average_tokens_across_devices = True,
377
+ max_grad_norm = 1.0,
378
+ label_smoothing_factor = 0.0,
379
+ bf16 = False,
380
+ fp16 = False,
381
+ bf16_full_eval = False,
382
+ fp16_full_eval = False,
383
+ tf32 = None,
384
+ gradient_checkpointing = True,
385
+ gradient_checkpointing_kwargs = None,
386
+ torch_compile = False,
387
+ torch_compile_backend = None,
388
+ torch_compile_mode = None,
389
+ use_liger_kernel = False,
390
+ liger_kernel_config = None,
391
+ use_cache = False,
392
+ neftune_noise_alpha = None,
393
+ torch_empty_cache_steps = 250,
394
+ auto_find_batch_size = False,
395
+ logging_strategy = 'steps',
396
+ logging_steps = 1,
397
+ logging_first_step = False,
398
+ log_on_each_node = True,
399
+ logging_nan_inf_filter = False,
400
+ include_num_input_tokens_seen = False,
401
+ log_level = 'passive',
402
+ log_level_replica = 'warning',
403
+ disable_tqdm = None,
404
+ report_to = 'none',
405
+ run_name = None,
406
+ project = 'huggingface',
407
+ trackio_space_id = 'trackio',
408
+ eval_strategy = 'no',
409
+ eval_steps = None,
410
+ eval_delay = 0,
411
+ per_device_eval_batch_size = 4,
412
+ prediction_loss_only = False,
413
+ eval_on_start = False,
414
+ eval_do_concat_batches = True,
415
+ eval_use_gather_object = False,
416
+ eval_accumulation_steps = 2,
417
+ batch_eval_metrics = False,
418
+ save_only_model = False,
419
+ save_strategy = 'steps',
420
+ save_steps = 500,
421
+ save_on_each_node = False,
422
+ save_total_limit = None,
423
+ enable_jit_checkpoint = False,
424
+ push_to_hub = False,
425
+ hub_token = None,
426
+ hub_private_repo = None,
427
+ hub_model_id = None,
428
+ hub_strategy = 'every_save',
429
+ hub_always_push = False,
430
+ hub_revision = None,
431
+ load_best_model_at_end = False,
432
+ metric_for_best_model = None,
433
+ greater_is_better = None,
434
+ ignore_data_skip = False,
435
+ restore_callback_states_from_checkpoint = False,
436
+ full_determinism = False,
437
+ seed = 3407,
438
+ data_seed = 3407,
439
+ use_cpu = False,
440
+ accelerator_config = None,
441
+ parallelism_config = None,
442
+ dataloader_drop_last = False,
443
+ dataloader_num_workers = 0,
444
+ dataloader_pin_memory = True,
445
+ dataloader_persistent_workers = False,
446
+ dataloader_prefetch_factor = None,
447
+ remove_unused_columns = True,
448
+ label_names = None,
449
+ train_sampling_strategy = 'random',
450
+ length_column_name = 'length',
451
+ ddp_find_unused_parameters = None,
452
+ ddp_bucket_cap_mb = None,
453
+ ddp_broadcast_buffers = None,
454
+ ddp_backend = None,
455
+ ddp_timeout = 1800,
456
+ fsdp = None,
457
+ fsdp_config = None,
458
+ deepspeed = None,
459
+ debug = '',
460
+ skip_memory_metrics = True,
461
+ do_train = False,
462
+ do_eval = False,
463
+ do_predict = False,
464
+ resume_from_checkpoint = None,
465
+ warmup_ratio = None,
466
+ logging_dir = None,
467
+ local_rank = -1,
468
+ reward_model_path = None,
469
+ judge = None,
470
+ max_new_tokens = 64,
471
+ max_length = 512,
472
+ temperature = 0.9,
473
+ top_p = 1.0,
474
+ top_k = None,
475
+ min_p = None,
476
+ repetition_penalty = 1.0,
477
+ generation_kwargs = {},
478
+ use_transformers_paged = False,
479
+ cache_implementation = None,
480
+ missing_eos_penalty = None,
481
+ loss_type = 'sigmoid',
482
+ disable_dropout = True,
483
+ use_vllm = False,
484
+ vllm_model_impl = 'vllm',
485
+ vllm_guided_decoding_regex = None,
486
+ vllm_gpu_memory_utilization = 0.55,
487
+ vllm_mode = 'colocate',
488
+ vllm_server_base_url = None,
489
+ vllm_server_host = '0.0.0.0',
490
+ vllm_server_port = 8000,
491
+ vllm_server_timeout = 240.0,
492
+ vllm_tensor_parallel_size = 1,
493
+ ds3_gather_for_generation = True,
494
+ model_init_kwargs = None,
495
+ reward_weights = None,
496
+ dataset_num_proc = None,
497
+ gpu_memory_utilization = None,
498
+ vllm_sampling_params = None,
499
+ unsloth_num_chunks = -1,
500
+ unsloth_logit_chunk_multiplier = None,
501
+ unsloth_grpo_mini_batch = None,
502
+ max_seq_length = None,
503
+ **kwargs,
504
+ ):
505
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
506
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1 or below, otherwise gradient updates will explode!')
507
+ if num_train_epochs is None:
508
+ num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
509
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
510
+ output_dir = 'unsloth_training_checkpoints'
511
+ save_strategy = 'no'
512
+ import multiprocessing as _mp
513
+ if _mp.get_start_method() != 'fork':
514
+ dataset_num_proc = None
515
+ elif dataset_num_proc is None:
516
+ import psutil
517
+ dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
518
+ memory_gb_left = psutil.virtual_memory().available / (1024**3)
519
+ if memory_gb_left <= 2: dataset_num_proc = 1
520
+ else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
521
+ if temperature <= 0:
522
+ raise ValueError('Unsloth: Please set a positive non-zero temperature, otherwise your results will be wrong.')
523
+ elif temperature >= 10:
524
+ raise ValueError('Unsloth: Please set a positive non-zero temperature below 10, since sampling becomes quite erratic above that.')
525
+
526
+
527
+ super().__init__(
528
+ output_dir = output_dir,
529
+ per_device_train_batch_size = per_device_train_batch_size,
530
+ num_train_epochs = num_train_epochs,
531
+ max_steps = max_steps,
532
+ learning_rate = learning_rate,
533
+ lr_scheduler_type = lr_scheduler_type,
534
+ lr_scheduler_kwargs = lr_scheduler_kwargs,
535
+ warmup_steps = warmup_steps,
536
+ optim = optim,
537
+ optim_args = optim_args,
538
+ weight_decay = weight_decay,
539
+ adam_beta1 = adam_beta1,
540
+ adam_beta2 = adam_beta2,
541
+ adam_epsilon = adam_epsilon,
542
+ optim_target_modules = optim_target_modules,
543
+ gradient_accumulation_steps = gradient_accumulation_steps,
544
+ average_tokens_across_devices = average_tokens_across_devices,
545
+ max_grad_norm = max_grad_norm,
546
+ label_smoothing_factor = label_smoothing_factor,
547
+ bf16 = bf16,
548
+ fp16 = fp16,
549
+ bf16_full_eval = bf16_full_eval,
550
+ fp16_full_eval = fp16_full_eval,
551
+ tf32 = tf32,
552
+ gradient_checkpointing = gradient_checkpointing,
553
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
554
+ torch_compile = torch_compile,
555
+ torch_compile_backend = torch_compile_backend,
556
+ torch_compile_mode = torch_compile_mode,
557
+ use_liger_kernel = use_liger_kernel,
558
+ liger_kernel_config = liger_kernel_config,
559
+ use_cache = use_cache,
560
+ neftune_noise_alpha = neftune_noise_alpha,
561
+ torch_empty_cache_steps = torch_empty_cache_steps,
562
+ auto_find_batch_size = auto_find_batch_size,
563
+ logging_strategy = logging_strategy,
564
+ logging_steps = logging_steps,
565
+ logging_first_step = logging_first_step,
566
+ log_on_each_node = log_on_each_node,
567
+ logging_nan_inf_filter = logging_nan_inf_filter,
568
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
569
+ log_level = log_level,
570
+ log_level_replica = log_level_replica,
571
+ disable_tqdm = disable_tqdm,
572
+ report_to = report_to,
573
+ run_name = run_name,
574
+ project = project,
575
+ trackio_space_id = trackio_space_id,
576
+ eval_strategy = eval_strategy,
577
+ eval_steps = eval_steps,
578
+ eval_delay = eval_delay,
579
+ per_device_eval_batch_size = per_device_eval_batch_size,
580
+ prediction_loss_only = prediction_loss_only,
581
+ eval_on_start = eval_on_start,
582
+ eval_do_concat_batches = eval_do_concat_batches,
583
+ eval_use_gather_object = eval_use_gather_object,
584
+ eval_accumulation_steps = eval_accumulation_steps,
585
+ batch_eval_metrics = batch_eval_metrics,
586
+ save_only_model = save_only_model,
587
+ save_strategy = save_strategy,
588
+ save_steps = save_steps,
589
+ save_on_each_node = save_on_each_node,
590
+ save_total_limit = save_total_limit,
591
+ enable_jit_checkpoint = enable_jit_checkpoint,
592
+ push_to_hub = push_to_hub,
593
+ hub_token = hub_token,
594
+ hub_private_repo = hub_private_repo,
595
+ hub_model_id = hub_model_id,
596
+ hub_strategy = hub_strategy,
597
+ hub_always_push = hub_always_push,
598
+ hub_revision = hub_revision,
599
+ load_best_model_at_end = load_best_model_at_end,
600
+ metric_for_best_model = metric_for_best_model,
601
+ greater_is_better = greater_is_better,
602
+ ignore_data_skip = ignore_data_skip,
603
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
604
+ full_determinism = full_determinism,
605
+ seed = seed,
606
+ data_seed = data_seed,
607
+ use_cpu = use_cpu,
608
+ accelerator_config = accelerator_config,
609
+ parallelism_config = parallelism_config,
610
+ dataloader_drop_last = dataloader_drop_last,
611
+ dataloader_num_workers = dataloader_num_workers,
612
+ dataloader_pin_memory = dataloader_pin_memory,
613
+ dataloader_persistent_workers = dataloader_persistent_workers,
614
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
615
+ remove_unused_columns = remove_unused_columns,
616
+ label_names = label_names,
617
+ train_sampling_strategy = train_sampling_strategy,
618
+ length_column_name = length_column_name,
619
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
620
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
621
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
622
+ ddp_backend = ddp_backend,
623
+ ddp_timeout = ddp_timeout,
624
+ fsdp = fsdp,
625
+ fsdp_config = fsdp_config,
626
+ deepspeed = deepspeed,
627
+ debug = debug,
628
+ skip_memory_metrics = skip_memory_metrics,
629
+ do_train = do_train,
630
+ do_eval = do_eval,
631
+ do_predict = do_predict,
632
+ resume_from_checkpoint = resume_from_checkpoint,
633
+ warmup_ratio = warmup_ratio,
634
+ logging_dir = logging_dir,
635
+ local_rank = local_rank,
636
+ reward_model_path = reward_model_path,
637
+ judge = judge,
638
+ max_new_tokens = max_new_tokens,
639
+ max_length = max_length,
640
+ temperature = temperature,
641
+ top_p = top_p,
642
+ top_k = top_k,
643
+ min_p = min_p,
644
+ repetition_penalty = repetition_penalty,
645
+ generation_kwargs = generation_kwargs,
646
+ use_transformers_paged = use_transformers_paged,
647
+ cache_implementation = cache_implementation,
648
+ missing_eos_penalty = missing_eos_penalty,
649
+ loss_type = loss_type,
650
+ disable_dropout = disable_dropout,
651
+ use_vllm = use_vllm,
652
+ vllm_model_impl = vllm_model_impl,
653
+ vllm_guided_decoding_regex = vllm_guided_decoding_regex,
654
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
655
+ vllm_mode = vllm_mode,
656
+ vllm_server_base_url = vllm_server_base_url,
657
+ vllm_server_host = vllm_server_host,
658
+ vllm_server_port = vllm_server_port,
659
+ vllm_server_timeout = vllm_server_timeout,
660
+ vllm_tensor_parallel_size = vllm_tensor_parallel_size,
661
+ ds3_gather_for_generation = ds3_gather_for_generation,
662
+ model_init_kwargs = model_init_kwargs,
663
+ reward_weights = reward_weights,
664
+ dataset_num_proc = dataset_num_proc,
665
+ gpu_memory_utilization = gpu_memory_utilization,**kwargs)
666
+ self.vllm_sampling_params = vllm_sampling_params
667
+ self.unsloth_num_chunks = unsloth_num_chunks
668
+ if unsloth_grpo_mini_batch is not None:
669
+ if self.generation_batch_size >= unsloth_grpo_mini_batch:
670
+ self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
671
+ else:
672
+ raise ValueError(
673
+ f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
674
+ f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
675
+ )
676
+ self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
677
+ self.max_seq_length = max_seq_length
678
+
679
+ pass
680
+
681
+ class _UnslothNashMDTrainer(OnlineDPOTrainer):
682
+ """"""
683
+
684
+ _tag_names = ["trl", "nash-md"]
685
+ _name = "Nash-MD"
686
+ _paper = {
687
+ "title": "Nash Learning from Human Feedback",
688
+ "id": "2312.00886",
689
+ # docstyle-ignore
690
+ "citation": textwrap.dedent("""\
691
+ @inproceedings{munos2024nash,
692
+ title = {{Nash Learning from Human Feedback}},
693
+ author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
694
+ year = 2024,
695
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
696
+ publisher = {OpenReview.net},
697
+ url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
698
+ }"""),
699
+ }
700
+
701
+ def __init__(
702
+ self,
703
+ model: Union[PreTrainedModel, nn.Module] = None,
704
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
705
+ reward_funcs: Union[PreTrainedModel, nn.Module, None] = None,
706
+ judge: Optional[BasePairwiseJudge] = None,
707
+ args: Optional[NashMDConfig] = None,
708
+ data_collator: Optional[Callable] = None,
709
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
710
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
711
+ processing_class: Optional[
712
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
713
+ ] = None,
714
+ peft_config: Optional[dict] = None,
715
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
716
+ callbacks: Optional[list[TrainerCallback]] = None,
717
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
718
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
719
+ # Deprecated parameters
720
+ reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
721
+ ) -> None:
722
+ super().__init__(
723
+ model=model,
724
+ ref_model=ref_model,
725
+ reward_funcs=reward_funcs,
726
+ judge=judge,
727
+ args=args,
728
+ data_collator=data_collator,
729
+ train_dataset=train_dataset,
730
+ eval_dataset=eval_dataset,
731
+ processing_class=processing_class,
732
+ reward_processing_classes=processing_class,
733
+ peft_config=peft_config,
734
+ compute_metrics=compute_metrics,
735
+ callbacks=callbacks,
736
+ optimizers=optimizers,
737
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
738
+ reward_model=reward_model,
739
+ )
740
+
741
+ self._mixture_coef = self.args.mixture_coef
742
+
743
+ # Overwrite the stats dictionary to include NashMD specific statistics
744
+ self.stats = {
745
+ # Remove "non_score_reward", "rlhf_reward", "scores_margin"
746
+ # Add "mixture_coef"
747
+ "loss/kl": [],
748
+ "objective/entropy": [],
749
+ "loss/score": [],
750
+ "rewards/probabilities": [],
751
+ "rewards/accuracies": [],
752
+ "rewards/margins": [],
753
+ "logps/chosen": [],
754
+ "logps/rejected": [],
755
+ "val/model_contain_eos_token": [],
756
+ "val/ref_contain_eos_token": [],
757
+ "beta": [],
758
+ "mixture_coef": [],
759
+ }
760
+ if self.reward_funcs is not None:
761
+ if len(self.reward_funcs) != 1:
762
+ raise ValueError("NashMDTrainer only supports one reward function/model.")
763
+ self.reward_funcs = self.reward_funcs[0]
764
+ self.stats["rewards/chosen"] = []
765
+ self.stats["rewards/rejected"] = []
766
+
767
+ @property
768
+ def mixture_coef(self):
769
+ if isinstance(self._mixture_coef, list):
770
+ epoch = self.state.epoch
771
+ return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
772
+ else:
773
+ return self._mixture_coef
774
+
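# Schedule note: with mixture_coef = [0.5, 0.4, 0.3], epoch 0 uses 0.5, epoch 1
# uses 0.4, and every later epoch reuses the last entry, 0.3; a scalar value is
# used unchanged for all epochs.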
775
+ def _generate_completions(self, model, prompts):
776
+ # Generate completions from the policy model.
777
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
778
+ model_output = unwrapped_policy_for_gen_ctx.generate(
779
+ input_ids=prompts["input_ids"],
780
+ attention_mask=prompts["attention_mask"],
781
+ generation_config=self.generation_config,
782
+ )
783
+
784
+ # Get the DDP/FSDP unwrapped version of the main model.
785
+ # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
786
+ policy_model_for_gmw = self.accelerator.unwrap_model(model)
787
+
788
+ # Determine the correct reference model for GeometricMixtureWrapper.
789
+ # This also needs to be DDP/FSDP unwrapped.
790
+ ref_model_for_gmw: torch.nn.Module
791
+ if self.ref_model is None:
792
+ # No explicit ref_model is provided.
793
+ # Use the base of the main `model` if it's a PEFT model.
794
+ # policy_model_for_gmw is already DDP-unwrapped.
795
+ if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
796
+ ref_model_for_gmw = policy_model_for_gmw.get_base_model()
797
+ else:
798
+ # Not a PEFT model (or PEFT not available), or already a base model.
799
+ # Use the DDP-unwrapped policy model itself as the reference.
800
+ ref_model_for_gmw = policy_model_for_gmw
801
+ else:
802
+ # An explicit ref_model is provided. Unwrap it for DDP/FSDP.
803
+ ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
804
+
805
+ # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
806
+ with torch.no_grad(): # Ensure no_grad context for mixture model generation
807
+ mixture_model = GeometricMixtureWrapper(
808
+ model=policy_model_for_gmw,
809
+ ref_model=ref_model_for_gmw,
810
+ generation_config=self.generation_config,
811
+ mixture_coef=self.mixture_coef,
812
+ device=self.accelerator.device,
813
+ )
814
+
815
+ mixture_output = mixture_model.generate(
816
+ input_ids=prompts["input_ids"],
817
+ attention_mask=prompts["attention_mask"],
818
+ generation_config=self.generation_config,
819
+ )
820
+
821
+ return model_output, mixture_output
822
+
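# Mixture note: GeometricMixtureWrapper generates from a geometric mixture of
# the policy and the reference model, which (as I read TRL's implementation)
# amounts to mixing the two models' logits linearly,
#     logits_mix = (1 - mixture_coef) * logits_policy + mixture_coef * logits_ref
# so that, up to normalization, pi_mix is proportional to pi_theta^(1-c) * pi_ref^c.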
823
+ def _process_completions(self, model_output, mixture_output, prompts):
824
+ context_length = prompts["input_ids"].shape[1]
825
+
826
+ # Process model completions
827
+ model_completion_ids = model_output[:, context_length:]
828
+ model_completion_ids, model_completion_mask = truncate_right(
829
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
830
+ )
831
+ model_data = {
832
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
833
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
834
+ "raw": prompts["raw"],
835
+ }
836
+
837
+ # Process mixture model completions
838
+ mixture_completion_ids = mixture_output[:, context_length:]
839
+ mixture_completion_ids, mixture_completion_mask = truncate_right(
840
+ mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
841
+ )
842
+ mixture_data = {
843
+ "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
844
+ "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
845
+ "raw": prompts["raw"],
846
+ }
847
+
848
+ return model_data, mixture_data
849
+
850
+ def _compute_rewards(self, model_data, mixture_data, context_length):
851
+ with torch.no_grad():
852
+ _, model_scores, _ = get_reward(
853
+ self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
854
+ )
855
+ _, mixture_scores, _ = get_reward(
856
+ self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
857
+ )
858
+
859
+ # Apply EOS penalty if needed
860
+ if self.args.missing_eos_penalty is not None:
861
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
862
+ mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
863
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
864
+ mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
865
+
866
+ return model_scores, mixture_scores
867
+
868
+ def _compute_judge(self, model_data, mixture_data, context_length):
869
+ prompts = model_data["raw"]
870
+ model_data_completions = self.processing_class.batch_decode(
871
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
872
+ )
873
+ model_data_completions = [completion.strip() for completion in model_data_completions]
874
+
875
+ mixture_data_completions = self.processing_class.batch_decode(
876
+ mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
877
+ )
878
+ mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
879
+ if is_conversational({"prompt": prompts[0]}):
880
+ model_data_completions = [
881
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
882
+ ]
883
+ environment = jinja2.Environment()
884
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
885
+ prompts = [template.render(messages=message) for message in prompts]
886
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
887
+
888
+ mixture_data_completions = [
889
+ [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
890
+ ]
891
+ mixture_data_completions = [
892
+ template.render(messages=completion) for completion in mixture_data_completions
893
+ ]
894
+
895
+ probability = self.judge.judge(
896
+ prompts,
897
+ list(zip(model_data_completions, mixture_data_completions)),
898
+ return_scores=True,
899
+ )
900
+ return torch.tensor(probability, device=model_data["input_ids"].device)
901
+
902
+ def _compute_logprobs(self, model, model_data, context_length):
903
+ def compute_logprobs_for_data(m, data):
904
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
905
+ logits = output.logits[:, context_length - 1 : -1]
906
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
907
+ return token_logprobs
908
+
909
+ # Compute logprobs for model completions under the model
910
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
911
+
912
+ # Compute logprobs of model completions under the reference model
913
+ with torch.no_grad():
914
+ if self.ref_model is None:
915
+ with model.disable_adapter():
916
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
917
+ else:
918
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
919
+
920
+ # Mask padding tokens
921
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
922
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
923
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
924
+
925
+ return (model_logprobs_model_data, ref_logprobs_model_data)
926
+
927
+ def _compute_losses(
928
+ self,
929
+ model_logprobs_model_data,
930
+ ref_logprobs_model_data,
931
+ probability,
932
+ ):
933
+ # REINFORCE score, using 0.5 as a control variate (baseline)
934
+ score = (probability - 0.5) * model_logprobs_model_data.sum(1)
935
+
936
+ # KL divergence estimated via the REINFORCE trick
937
+ with torch.no_grad():
938
+ log_ratio = model_logprobs_model_data - ref_logprobs_model_data
939
+ kl_div_log = log_ratio.sum(1)
940
+ kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
941
+
942
+ # final loss
943
+ loss = self.beta * kl_div_loss - score
944
+
945
+ return loss.mean(), score, kl_div_log
946
+
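# Objective note, per sequence with completion tokens t (as implemented above):
#     score = (p - 0.5) * sum_t log pi_theta(y_t)            # REINFORCE, 0.5 baseline
#     kl    = sum_t stopgrad(log pi_theta - log pi_ref)_t * log pi_theta(y_t)
#     loss  = beta * kl - score
# where p is the probability that the policy sample beats the mixture sample;
# the log-ratio is taken under no_grad, giving a REINFORCE estimator of the
# gradient of the KL term.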
947
+ def _log_statistics(
948
+ self,
949
+ model_data,
950
+ mixture_data,
951
+ model_logprobs_model_data,
952
+ ref_logprobs_model_data,
953
+ probability,
954
+ score,
955
+ kl_div,
956
+ context_length,
957
+ model_scores=None,
958
+ mixture_scores=None,
959
+ ):
960
+ # Helper function to gather and compute mean
961
+ def gather_mean(tensor):
962
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
963
+
964
+ # Log score
965
+ self.stats["loss/score"].append(gather_mean(score))
966
+ # Log KL divergence
967
+ self.stats["loss/kl"].append(gather_mean(kl_div))
968
+
969
+ # Log logprobs
970
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
971
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
972
+
973
+ self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
974
+ self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
975
+
976
+ # Log rewards
977
+ if self.reward_funcs is not None:
978
+ self.stats["rewards/chosen"].append(gather_mean(model_scores))
979
+ self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
980
+
981
+ # Log probabilities
982
+ self.stats["rewards/probabilities"].append(gather_mean(probability))
983
+
984
+ # Calculate entropy for model data
985
+ entropy_model_data = -model_logprobs_model_data.sum(1)
986
+ self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
987
+
988
+ # Calculate margins
989
+ margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
990
+ self.stats["rewards/margins"].append(gather_mean(margin))
991
+
992
+ # Calculate accuracy
993
+ accuracy = (margin > 0).float()
994
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy))
995
+
996
+ # Log EOS token statistics
997
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
998
+ mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
999
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
1000
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
1001
+
1002
+ # Log beta and mixture coef
1003
+ self.stats["beta"].append(self.beta)
1004
+ self.stats["mixture_coef"].append(self.mixture_coef)
1005
+
1006
+ def training_step(
1007
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
1008
+ ) -> torch.Tensor:
1009
+ model.train()
1010
+
1011
+ # Apply chat template and tokenize the input
1012
+ batch_size = len(next(iter(inputs.values())))
1013
+ prompts = inputs["prompt"]
1014
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
1015
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
1016
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
1017
+ inputs = self.data_collator(inputs)
1018
+
1019
+ # Only the prompt_* tensors are needed from here on
1020
+ inputs = self._prepare_inputs(inputs)
1021
+ context_length = inputs["prompt_input_ids"].shape[1]
1022
+ prompts = {
1023
+ "input_ids": inputs["prompt_input_ids"],
1024
+ "attention_mask": inputs["prompt_attention_mask"],
1025
+ "raw": prompts,
1026
+ }
1027
+ del inputs
1028
+
1029
+ # Sample completions from both the model and the reference model
1030
+ model_output, mixture_output = self._generate_completions(model, prompts)
1031
+
1032
+ # Process model completions
1033
+ model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
1034
+
1035
+ # Compute rewards
1036
+ if self.reward_funcs is not None:
1037
+ model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
1038
+ # probability of the model data vs the mixture data
1039
+ probability = F.sigmoid(model_scores - mixture_scores)
1040
+ else:
1041
+ model_scores, mixture_scores = None, None
1042
+ probability = self._compute_judge(model_data, mixture_data, context_length)
1043
+
1044
+ # Compute logprobs
1045
+ model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
1046
+
1047
+ # Compute loss
1048
+ loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
1049
+
1050
+ # Log everything
1051
+ self._log_statistics(
1052
+ model_data,
1053
+ mixture_data,
1054
+ model_logprobs_model_data.detach(),
1055
+ ref_logprobs_model_data,
1056
+ probability,
1057
+ score.detach(),
1058
+ kl_div.detach(),
1059
+ context_length,
1060
+ model_scores,
1061
+ mixture_scores,
1062
+ )
1063
+
1064
+ if (
1065
+ self.args.torch_empty_cache_steps is not None
1066
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
1067
+ ):
1068
+ empty_cache()
1069
+
1070
+ kwargs = {}
1071
+ # For LOMO optimizers you need to explicitly use the learning rate
1072
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
1073
+ kwargs["learning_rate"] = self._get_learning_rate()
1074
+
1075
+ if self.args.n_gpu > 1:
1076
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
1077
+
1078
+ self.accelerator.backward(loss, **kwargs)
1079
+
1080
+ return loss.detach() / self.args.gradient_accumulation_steps
1081
+ class UnslothNashMDTrainer(_UnslothNashMDTrainer):
1082
+ """
1083
+
1084
+ Trainer for the Nash-MD method.
1085
+
1086
+ It is implemented as a subclass of [`OnlineDPOTrainer`].
1087
+
1088
+ Args:
1089
+ model ([`~transformers.PreTrainedModel`]):
1090
+ The model to train, preferably an `AutoModelForCausalLM`.
1091
+ ref_model ([`PreTrainedModelWrapper`]):
1092
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
1093
+ and loss. If no reference model is provided, the trainer will create a reference model with the same
1094
+ architecture as the model to be optimized.
1095
+ reward_funcs ([`~transformers.PreTrainedModel`]):
1096
+ The reward model to score completions with, preferably an
1097
+ [`~transformers.AutoModelForSequenceClassification`].
1098
+ judge ([`BasePairwiseJudge`]):
1099
+ The judge to use for pairwise comparison of model completions.
1100
+ args ([`NashMDConfig`]):
1101
+ The NashMD config arguments to use for training.
1102
+ data_collator ([`~transformers.DataCollator`]):
1103
+ The data collator to use for training. If None is specified, the default data collator
1104
+ ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
1105
+ sequences in the batch, given a dataset of paired sequences.
1106
+ train_dataset ([`~datasets.Dataset`]):
1107
+ The dataset to use for training.
1108
+ eval_dataset ([`~datasets.Dataset`]):
1109
+ The dataset to use for evaluation.
1110
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
1111
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1112
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1113
+ reuse the fine-tuned model.
1114
+ peft_config (`dict`):
1115
+ The peft config to use for training.
1116
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1117
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping
1118
+ metric names to metric values.
1119
+ callbacks (`list[transformers.TrainerCallback]`):
1120
+ The callbacks to use for training.
1121
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1122
+ The optimizer and scheduler to use for training.
1123
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1124
+ The function to use to preprocess the logits before computing the metrics.
1125
+
1126
+ reward_model:
1127
+
1128
+ <Deprecated version="0.22.0">
1129
+
1130
+ This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
1131
+
1132
+ </Deprecated>
1133
+
1134
+ """
1135
+ def __init__(
1136
+ self,
1137
+ model = None,
1138
+ ref_model = None,
1139
+ reward_funcs = None,
1140
+ judge = None,
1141
+ args = None,
1142
+ data_collator = None,
1143
+ train_dataset = None,
1144
+ eval_dataset = None,
1145
+ processing_class = None,
1146
+ peft_config = None,
1147
+ compute_metrics = None,
1148
+ callbacks = None,
1149
+ preprocess_logits_for_metrics = None,
1150
+ reward_model = None,
1151
+ **kwargs
1152
+ ):
1153
+ if args is None: args = UnslothNashMDConfig()
1154
+ use_bf16 = getattr(args, 'bf16', False)
1155
+ if type(use_bf16) is not bool: use_bf16 = False
1156
+ use_fp16 = getattr(args, 'fp16', False)
1157
+ if type(use_fp16) is not bool: use_fp16 = False
1158
+ force_float32 = False
1159
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
1160
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
1161
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1162
+ force_float32 = True
1163
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1164
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
1165
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
1166
+ from unsloth_zoo.utils import _get_dtype
1167
+ dtype = _get_dtype(dtype)
1168
+ float16 = dtype == torch.float16
1169
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1170
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1171
+ if force_float32:
1172
+ # Forced float32 training
1173
+ args.fp16 = False
1174
+ args.bf16 = False
1175
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1176
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1177
+ # args.mixed_precision is a new argument which needs to be set now
1178
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1179
+ # Mixed precision training
1180
+ args.fp16 = float16
1181
+ args.bf16 = not float16
1182
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1183
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
1184
+ # args.mixed_precision is a new argument which needs to be set now
1185
+ elif mixed_precision_dtype == 'bfloat16':
1186
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
1187
+ args.fp16 = False
1188
+ args.bf16 = False
1189
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1190
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1191
+ # args.mixed_precision is a new argument which needs to be set now
1192
+
1193
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1194
+ args.eval_strategy = 'steps'
1195
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1196
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1197
+ if ga_steps is not None and ga_steps > 1:
1198
+ from transformers import __version__ as transformers_version
1199
+ if Version(transformers_version) <= Version('4.45.2'):
1200
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1201
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1202
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1203
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1204
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1205
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1206
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1207
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1208
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1209
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1210
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1211
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1212
+ if force_float32:
1213
+ args.bf16_full_eval = False
1214
+ args.fp16_full_eval = False
1215
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1216
+ args.bf16_full_eval = True
1217
+ args.fp16_full_eval = False
1218
+ elif not bf16_full_eval and not fp16_full_eval:
1219
+ args.bf16_full_eval = args.bf16
1220
+ args.fp16_full_eval = args.fp16
1221
+ _output_logits = False
1222
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1223
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1224
+ if _output_logits:
1225
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1226
+ if model is not None:
1227
+ _warnings_issued = getattr(model, 'warnings_issued', None)
1228
+ if _warnings_issued is None:
1229
+ model.warnings_issued = {}
1230
+ elif not isinstance(_warnings_issued, dict):
1231
+ try:
1232
+ model.warnings_issued = dict(_warnings_issued)
1233
+ except Exception:
1234
+ model.warnings_issued = {}
1235
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1236
+ pass
1237
+ else:
1238
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1239
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1240
+ if args_max_seq_length is None and model_max_seq_length is not None:
1241
+ max_seq_length = model.max_seq_length
1242
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1243
+ elif args_max_seq_length is not None and model_max_seq_length is not None:
1244
+ if args_max_seq_length > model_max_seq_length:
1245
+ print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
1246
+ 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
1247
+ args.max_seq_length = model_max_seq_length
1248
+ if model is not None and hasattr(model, 'for_training'):
1249
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1250
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1251
+ if 'processing_class' in locals():
1252
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1253
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1254
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = TransformersDataCollatorForLanguageModeling(
+                    __tokenizer,
+                    mlm = False,
+                    mlm_probability = 0.0,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(
+                    __tokenizer,
+                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                )
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(
+                        __tokenizer.tokenizer,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+                else:
+                    data_collator = TransformersDataCollatorForLanguageModeling(
+                        __tokenizer.tokenizer,
+                        mlm = False,
+                        mlm_probability = 0.0,
+                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                    )
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('nash_md_trainer', other_metrics)
+
+        # [TODO] Fix up DataParallel multiplying batch sizes
+        # [TODO] DDP works, but DP seems to not work? [TODO]
+        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+            if getattr(args, "_n_gpu", 1) != 1:
+                args._n_gpu = 1
+        if "model" in locals() and hasattr(model, "for_training"):
+            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+        super().__init__(
+            model = model,
+            ref_model = ref_model,
+            reward_funcs = reward_funcs,
+            judge = judge,
+            args = args,
+            data_collator = data_collator,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            peft_config = peft_config,
+            compute_metrics = compute_metrics,
+            callbacks = callbacks,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            reward_model = reward_model, **kwargs)
+        if "model" in locals() and hasattr(model, "for_inference"):
+            model.for_inference()
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+        if hasattr(self, 'accelerator'):
+            scaler = self.accelerator.scaler
+            current_model = model
+            while hasattr(current_model, 'model'):
+                current_model.accelerator_scaler = scaler
+                current_model = current_model.model
+            current_model.accelerator_scaler = scaler
+        pass
+        if hasattr(self, 'train'):
+            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
+        pass
+        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
+            _vllm_tok = self.llm.get_tokenizer()
+            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
+            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
+                _vllm_tok.chat_template = _pc.chat_template
+        pass
+
+        pass
unsloth_compiled_cache/UnslothORPOTrainer.py CHANGED
@@ -28,7 +28,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 from unsloth_zoo.temporary_patches.common import torch_compile
 from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
+from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
 
 
 import os
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothRLOOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothXPOTrainer.py ADDED
@@ -0,0 +1,1385 @@
+ """
2
+ 2026.3.2
3
+ 2026.3.4
4
+ 5.3.0
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from unsloth_zoo.temporary_patches.common import torch_compile
30
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
31
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
32
+
33
+
34
+ import os
35
+ import math
36
+ import logging
37
+ from typing import *
38
+ from dataclasses import dataclass, field
39
+ from packaging.version import Version
40
+ import torch
41
+ import numpy as np
42
+ from contextlib import nullcontext
43
+ from torch.nn import functional as F
44
+ import inspect
45
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
46
+ from transformers.training_args import ParallelMode
47
+ from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
48
+
49
+ # Wrap trainer with padding to right and enable training mode
50
+ # Also patches W&B since multiple runs must use wandb.finish()
51
+ import functools
52
+ from types import MethodType
53
+ try:
54
+ from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
55
+ except:
56
+ def reset_unsloth_gradient_checkpointing_buffers(): pass
57
+ def prepare_for_training_mode(f):
58
+ @functools.wraps(f)
59
+ def wrapper(self, *args, **kwargs):
60
+ # Enable training mode
61
+ _was_training = None
62
+ # Get gradient checkpointing setting from training arguments
63
+ use_gc = getattr(self.args, 'gradient_checkpointing', True)
64
+ if hasattr(self, 'model') and hasattr(self.model, "training"):
65
+ _was_training = self.model.training
66
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
67
+ self.model.for_training(use_gradient_checkpointing=use_gc)
68
+ output = f(self, *args, **kwargs)
69
+ # Restore previous mode when possible
70
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
71
+ if _was_training is False:
72
+ self.model.for_inference()
73
+ elif _was_training is True and hasattr(self.model, "for_training"):
74
+ self.model.for_training(use_gradient_checkpointing=use_gc)
75
+ # Reset gradient checkpointing buffers to free memory while staying ready for next run
76
+ try:
77
+ reset_unsloth_gradient_checkpointing_buffers()
78
+ except:
79
+ pass
80
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
81
+ try:
82
+ import wandb
83
+ wandb.finish()
84
+ except:
85
+ pass
86
+ return output
87
+ return wrapper
88
+ pass
89
+
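Editor's note: a minimal sketch of how `prepare_for_training_mode` is meant to behave, using stand-in `ToyModel`/`ToyTrainer`/`Args` classes (illustrative names, not part of the generated file; assumes the wrapper definition above is in scope):

    from types import MethodType

    class ToyModel:
        def __init__(self): self.training = False
        def for_training(self, use_gradient_checkpointing = True): self.training = True
        def for_inference(self): self.training = False

    class Args: gradient_checkpointing = True

    class ToyTrainer:
        def __init__(self): self.model, self.args = ToyModel(), Args()
        def train(self): return f"ran with training={self.model.training}"

    trainer = ToyTrainer()
    # Same patching pattern the generated __init__ applies to self.train:
    trainer.train = MethodType(prepare_for_training_mode(ToyTrainer.train), trainer)
    print(trainer.train())         # "ran with training=True"
    print(trainer.model.training)  # False again: the pre-call mode is restored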
+ torch_compile_options = {
+     "epilogue_fusion"   : True,
+     "max_autotune"      : False,
+     "shape_padding"     : True,
+     "trace.enabled"     : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_hidden_states_selective_log_softmax(
+     hidden_states: torch.Tensor,
+     lm_head: torch.Tensor,
+     index: torch.Tensor,
+     chunks: int = 4,
+     logit_scale_multiply: float = 0.0,
+     logit_scale_divide: float = 0.0,
+     logit_softcapping: float = 0.0,
+     temperature: float = 1.0,
+ ) -> torch.Tensor:
+     # All Unsloth Zoo code licensed under AGPL3
+     flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+     flat_index = index.reshape(-1)
+
+     chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+     chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+     all_per_token_logps = []
+
+     for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+         chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+         if logit_scale_multiply != 0.0:
+             chunk_logits = chunk_logits * logit_scale_multiply
+         if logit_scale_divide != 0.0:
+             chunk_logits = chunk_logits / logit_scale_divide
+         if logit_softcapping != 0.0:
+             chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+         chunk_logits = chunk_logits.to(torch.float32)
+
+         if temperature != 1.0:
+             chunk_logits = chunk_logits / temperature
+
+         selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+         logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+         per_token_logps = selected_logits - logsumexp_values
+         all_per_token_logps.append(per_token_logps)
+
+     all_per_token_logps = torch.concat(all_per_token_logps)
+
+     all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+     return all_per_token_logps
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def chunked_selective_log_softmax(logits, index):
+     # Split into 4 chunks only
+     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+     all_per_token_logps = []
+     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+         chunk_logits = chunk_logits.to(torch.float32)
+         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+         per_token_logps = selected_logits - logsumexp_values
+         all_per_token_logps.append(per_token_logps)
+     pass
+     all_per_token_logps = torch.concat(all_per_token_logps)
+     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+     return all_per_token_logps
+
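Editor's note: both chunked functions compute per-token log-probabilities as logit[token] - logsumexp(logits) over vocab chunks, so the full softmax is never materialized. A quick equivalence check against the naive log_softmax-then-gather path (small shapes; the first call pays a torch.compile cost):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 8, 32)          # (batch, seq, vocab)
    index  = torch.randint(0, 32, (2, 8))   # target token ids per position

    naive = torch.gather(F.log_softmax(logits.float(), dim=-1), -1, index.unsqueeze(-1)).squeeze(-1)
    chunked = chunked_selective_log_softmax(logits, index)
    assert torch.allclose(naive, chunked, atol=1e-5)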
+ def calculate_pad_tokens_in_prompt(
+     input_ids: torch.Tensor,
+     logits_to_keep: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given a prompt tensor, returns the number of left-padding tokens in each sequence,
+     so [pad, pad, pad, cat] = 3 tokens.
+     """
+     if logits_to_keep >= input_ids.shape[1]:
+         raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+     prompt_section = input_ids[:, :-logits_to_keep]
+
+     padding_mask = (prompt_section == pad_token_id)
+
+     pad_token_counts = padding_mask.sum(dim=1)
+
+     return pad_token_counts
+
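Editor's note: a small usage sketch (assumes the definition above is in scope). The last `logits_to_keep` positions are sliced off as completion tokens before counting:

    import torch

    pad = 0
    # Two prompts of length 4 (left-padded), followed by 2 completion positions.
    input_ids = torch.tensor([
        [pad, pad, 5, 6, 7, 8],
        [pad, 3,   4, 6, 7, 8],
    ])
    counts = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep=2, pad_token_id=pad)
    print(counts)  # tensor([2, 1])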
+ def create_completion_attention_mask(
+     completion_input_ids: torch.Tensor,
+     left_pad_tokens_per_prompt: torch.Tensor,
+     max_left_pad: int,
+     pad_token_id: int
+ ) -> torch.Tensor:
+     """
+     Given a sequence [p, p, p, c, c, c, pad, pad, pad],
+
+     where p are extra prompt tokens picked up when slicing the tensor, c are completion
+     tokens, and pad are padding tokens, build a completion mask that zeroes out the pad
+     and p tokens: in this example [0, 0, 0, 1, 1, 1, 0, 0, 0].
+     """
+     batch_size, completion_len = completion_input_ids.shape
+     device = completion_input_ids.device
+
+     num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+     indices = torch.arange(completion_len, device=device).unsqueeze(0)
+     shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+     non_padding_mask = (completion_input_ids != pad_token_id)
+
+     final_mask = shift_mask & non_padding_mask
+
+     return final_mask
+
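Editor's note: a worked example (assumes the definition above is in scope). Rows whose prompt had fewer left-pads than the batch maximum carry leftover prompt tokens at the front of the completion slice, and those get masked out:

    import torch

    pad = 0
    completion_ids = torch.tensor([
        [9, 9, 5, 6, pad],    # 2 leftover prompt tokens (9, 9), then completion, then pad
        [5, 6, 7, pad, pad],  # no leftover prompt tokens
    ])
    left_pads = torch.tensor([1, 3])  # per-prompt left-pad counts; batch max is 3
    mask = create_completion_attention_mask(completion_ids, left_pads, max_left_pad=3, pad_token_id=pad)
    print(mask.int())
    # tensor([[0, 0, 1, 1, 0],
    #         [1, 1, 1, 0, 0]])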
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+     """
+     Moves all padding tokens in each sequence of a batch to the right.
+     """
+     mask = (tensor != pad_id)
+     # Must use stable=True since a binary mask gives no ordering within equal keys
+     sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+     packed_tensor = torch.gather(tensor, 1, sorted_indices)
+     return packed_tensor
+
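Editor's note: the stable argsort keeps the relative order of the non-pad tokens while pushing pads to the right. A quick sketch (assumes the definition above is in scope):

    import torch

    pad = 0
    x = torch.tensor([
        [pad, 7, pad, 8],
        [1, pad, 2, 3],
    ])
    print(left_pack_padding(x, pad))
    # tensor([[7, 8, 0, 0],
    #         [1, 2, 3, 0]])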
+ def align_logprobs_with_mask(
+     logprob_tensor: torch.Tensor,
+     attention_mask: torch.Tensor,
+     pad_value: float = 0.0
+ ) -> torch.Tensor:
+     """
+     Aligns a log probability tensor with a given attention mask.
+     """
+
+     device = logprob_tensor.device
+     batch_size, logprob_seq_len = logprob_tensor.shape
+     mask_seq_len = attention_mask.shape[1]
+
+     padded_logprobs = torch.full(
+         attention_mask.shape,
+         fill_value=pad_value,
+         dtype=logprob_tensor.dtype,
+         device=device
+     )
+
+     left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+     cols = torch.arange(logprob_seq_len, device=device)
+     dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+     # Create destination row indices
+     # Shape: [batch_size, logprob_seq_len]
+     row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+     # --- 4. Filter out-of-bounds indices and perform assignment ---
+     # Create a mask to identify only the indices that are within the bounds
+     # of the target tensor's sequence length.
+     valid_mask = dest_indices < mask_seq_len
+
+     # Use this mask to select only the valid row indices, column indices,
+     # and the corresponding values from the logprob tensor.
+     # This flattens the selected elements into 1D tensors.
+     valid_rows = row_indices[valid_mask]
+     valid_cols = dest_indices[valid_mask]
+     valid_vals = logprob_tensor[valid_mask]
+
+     # Place the valid values into their correct positions in the padded tensor
+     # using a single, efficient advanced indexing operation.
+     padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+     return padded_logprobs
+
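Editor's note: each row of log-probs is shifted right by that row's left-pad count (the argmax finds the first 1 in the attention mask). A tiny sketch (assumes the definition above is in scope):

    import torch

    logps = torch.tensor([[0.1, 0.2, 0.3]])
    attn  = torch.tensor([[0, 0, 1, 1, 1]])  # two left-pads
    print(align_logprobs_with_mask(logps, attn))
    # tensor([[0.0000, 0.0000, 0.1000, 0.2000, 0.3000]])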
+ def autotune_batch_and_chunks(
+     total_input_rows,
+     seq_len,
+     hidden_size,
+     vocab_size,
+     dtype_bytes=16,
+     multiplier=None
+ ):
+     if multiplier is None:
+         final_m = max(4, seq_len // 4096)
+     else:
+         final_m = multiplier
+
+     if torch.cuda.is_available():
+         free_bytes, _ = torch.cuda.mem_get_info()
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     elif hasattr(torch, "xpu") and torch.xpu.is_available():
+         # For XPU: estimate free memory from total - reserved
+         total_mem = torch.xpu.get_device_properties(0).total_memory
+         reserved_mem = torch.xpu.memory_reserved()
+         free_bytes = total_mem - reserved_mem
+         limit_gb = (free_bytes / (1024**3)) * 0.80
+     else:
+         # Fallback: assume 8GB available
+         limit_gb = 8.0
+
+     bytes_to_gb = 1024**3
+
+     b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+     hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+     base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+     logits_gb = base_logits / final_m
+
+     total_mem_gb = hidden_gb + logits_gb
+
+     valid_mask = total_mem_gb <= limit_gb
+     valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+     if valid_indices.shape[0] == 0:
+         # This means your GPU will OOM
+         return 4, final_m
+
+     best_idx = valid_indices[0].item()
+     final_b = int(b_vals[best_idx].item())
+
+     return final_b, final_m
+
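Editor's note: as I read it, the search tries batch sizes from `total_input_rows` down to 1 and keeps the largest one whose estimated hidden-state plus chunked-logit footprint fits in roughly 80% of free device memory. A usage sketch (assumes the definition above is in scope; the printed values depend on the machine):

    # Rough cost model being searched, in GB per candidate batch size b:
    #   hidden  = b * seq_len * hidden_size * dtype_bytes
    #   logits ~= (b / total_rows) * b * seq_len * vocab_size * dtype_bytes / multiplier
    b, m = autotune_batch_and_chunks(
        total_input_rows = 64,
        seq_len          = 2048,
        hidden_size      = 4096,
        vocab_size       = 128_256,
    )
    print(b, m)  # chosen micro-batch rows and the logit-chunk multiplier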
+ def sanitize_logprob(logprob):
+     """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+     Filters NaN logprobs from vLLM outputs."""
+     value = logprob.logprob
+     if math.isnan(value):
+         logging.getLogger(__name__).warning(
+             f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+         )
+         return None
+     return value
+
+ @dataclass
+ class UnslothXPOConfig(XPOConfig):
+     """
+
+     Configuration class for the [`XPOTrainer`].
+
+     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
+
+     Parameters:
+         alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
+             Weight of the XPO loss term. If a list of floats is provided, the alpha is selected for each new epoch
+             and the last alpha is used for the rest of the epochs.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     unsloth_logit_chunk_multiplier : Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Multiplier for chunked logit computations.'},
+     )
+     unsloth_grpo_mini_batch : Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+     )
+     max_seq_length : Optional[int] = field(
+         default = None,
+         metadata = {'help': 'Maximum sequence length to truncate to.'},
+     )
+     def __init__(
+         self,
+         output_dir = None,
+         per_device_train_batch_size = 4,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         learning_rate = 5e-05,
+         lr_scheduler_type = 'linear',
+         lr_scheduler_kwargs = None,
+         warmup_steps = 0.1,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         optim_target_modules = None,
+         gradient_accumulation_steps = 2,
+         average_tokens_across_devices = True,
+         max_grad_norm = 1.0,
+         label_smoothing_factor = 0.0,
+         bf16 = False,
+         fp16 = False,
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         gradient_checkpointing = True,
+         gradient_checkpointing_kwargs = None,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         use_liger_kernel = False,
+         liger_kernel_config = None,
+         use_cache = False,
+         neftune_noise_alpha = None,
+         torch_empty_cache_steps = 250,
+         auto_find_batch_size = False,
+         logging_strategy = 'steps',
+         logging_steps = 1,
+         logging_first_step = False,
+         log_on_each_node = True,
+         logging_nan_inf_filter = False,
+         include_num_input_tokens_seen = False,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         disable_tqdm = None,
+         report_to = 'none',
+         run_name = None,
+         project = 'huggingface',
+         trackio_space_id = 'trackio',
+         eval_strategy = 'no',
+         eval_steps = None,
+         eval_delay = 0,
+         per_device_eval_batch_size = 4,
+         prediction_loss_only = False,
+         eval_on_start = False,
+         eval_do_concat_batches = True,
+         eval_use_gather_object = False,
+         eval_accumulation_steps = 2,
+         batch_eval_metrics = False,
+         save_only_model = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_on_each_node = False,
+         save_total_limit = None,
+         enable_jit_checkpoint = False,
+         push_to_hub = False,
+         hub_token = None,
+         hub_private_repo = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_always_push = False,
+         hub_revision = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         restore_callback_states_from_checkpoint = False,
+         full_determinism = False,
+         seed = 3407,
+         data_seed = 3407,
+         use_cpu = False,
+         accelerator_config = None,
+         parallelism_config = None,
+         dataloader_drop_last = False,
+         dataloader_num_workers = 0,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         dataloader_prefetch_factor = None,
+         remove_unused_columns = True,
+         label_names = None,
+         train_sampling_strategy = 'random',
+         length_column_name = 'length',
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         ddp_backend = None,
+         ddp_timeout = 1800,
+         fsdp = None,
+         fsdp_config = None,
+         deepspeed = None,
+         debug = '',
+         skip_memory_metrics = True,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         resume_from_checkpoint = None,
+         warmup_ratio = None,
+         logging_dir = None,
+         local_rank = -1,
+         reward_model_path = None,
+         judge = None,
+         max_new_tokens = 64,
+         max_length = 512,
+         temperature = 0.9,
+         top_p = 1.0,
+         top_k = None,
+         min_p = None,
+         repetition_penalty = 1.0,
+         generation_kwargs = {},
+         use_transformers_paged = False,
+         cache_implementation = None,
+         missing_eos_penalty = None,
+         loss_type = 'sigmoid',
+         disable_dropout = True,
+         use_vllm = False,
+         vllm_model_impl = 'vllm',
+         vllm_guided_decoding_regex = None,
+         vllm_gpu_memory_utilization = 0.55,
+         vllm_mode = 'colocate',
+         vllm_server_base_url = None,
+         vllm_server_host = '0.0.0.0',
+         vllm_server_port = 8000,
+         vllm_server_timeout = 240.0,
+         vllm_tensor_parallel_size = 1,
+         ds3_gather_for_generation = True,
+         model_init_kwargs = None,
+         reward_weights = None,
+         dataset_num_proc = None,
+         gpu_memory_utilization = None,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         unsloth_logit_chunk_multiplier = None,
+         unsloth_grpo_mini_batch = None,
+         max_seq_length = None,
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+         if num_train_epochs is None:
+             num_train_epochs = 3.0  # Default to 3 epochs if None, max_steps will override
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         import multiprocessing as _mp
+         if _mp.get_start_method() != 'fork':
+             dataset_num_proc = None
+         elif dataset_num_proc is None:
+             import psutil
+             dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
+             memory_gb_left = psutil.virtual_memory().available / (1024**3)
+             if memory_gb_left <= 2: dataset_num_proc = 1
+             else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
+         if temperature <= 0:
+             raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
+         elif temperature >= 10:
+             raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
+
+
+         super().__init__(
+             output_dir = output_dir,
+             per_device_train_batch_size = per_device_train_batch_size,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             learning_rate = learning_rate,
+             lr_scheduler_type = lr_scheduler_type,
+             lr_scheduler_kwargs = lr_scheduler_kwargs,
+             warmup_steps = warmup_steps,
+             optim = optim,
+             optim_args = optim_args,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             optim_target_modules = optim_target_modules,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             average_tokens_across_devices = average_tokens_across_devices,
+             max_grad_norm = max_grad_norm,
+             label_smoothing_factor = label_smoothing_factor,
+             bf16 = bf16,
+             fp16 = fp16,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             use_liger_kernel = use_liger_kernel,
+             liger_kernel_config = liger_kernel_config,
+             use_cache = use_cache,
+             neftune_noise_alpha = neftune_noise_alpha,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             auto_find_batch_size = auto_find_batch_size,
+             logging_strategy = logging_strategy,
+             logging_steps = logging_steps,
+             logging_first_step = logging_first_step,
+             log_on_each_node = log_on_each_node,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             disable_tqdm = disable_tqdm,
+             report_to = report_to,
+             run_name = run_name,
+             project = project,
+             trackio_space_id = trackio_space_id,
+             eval_strategy = eval_strategy,
+             eval_steps = eval_steps,
+             eval_delay = eval_delay,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             prediction_loss_only = prediction_loss_only,
+             eval_on_start = eval_on_start,
+             eval_do_concat_batches = eval_do_concat_batches,
+             eval_use_gather_object = eval_use_gather_object,
+             eval_accumulation_steps = eval_accumulation_steps,
+             batch_eval_metrics = batch_eval_metrics,
+             save_only_model = save_only_model,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_on_each_node = save_on_each_node,
+             save_total_limit = save_total_limit,
+             enable_jit_checkpoint = enable_jit_checkpoint,
+             push_to_hub = push_to_hub,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_always_push = hub_always_push,
+             hub_revision = hub_revision,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             full_determinism = full_determinism,
+             seed = seed,
+             data_seed = data_seed,
+             use_cpu = use_cpu,
+             accelerator_config = accelerator_config,
+             parallelism_config = parallelism_config,
+             dataloader_drop_last = dataloader_drop_last,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             train_sampling_strategy = train_sampling_strategy,
+             length_column_name = length_column_name,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             ddp_backend = ddp_backend,
+             ddp_timeout = ddp_timeout,
+             fsdp = fsdp,
+             fsdp_config = fsdp_config,
+             deepspeed = deepspeed,
+             debug = debug,
+             skip_memory_metrics = skip_memory_metrics,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             resume_from_checkpoint = resume_from_checkpoint,
+             warmup_ratio = warmup_ratio,
+             logging_dir = logging_dir,
+             local_rank = local_rank,
+             reward_model_path = reward_model_path,
+             judge = judge,
+             max_new_tokens = max_new_tokens,
+             max_length = max_length,
+             temperature = temperature,
+             top_p = top_p,
+             top_k = top_k,
+             min_p = min_p,
+             repetition_penalty = repetition_penalty,
+             generation_kwargs = generation_kwargs,
+             use_transformers_paged = use_transformers_paged,
+             cache_implementation = cache_implementation,
+             missing_eos_penalty = missing_eos_penalty,
+             loss_type = loss_type,
+             disable_dropout = disable_dropout,
+             use_vllm = use_vllm,
+             vllm_model_impl = vllm_model_impl,
+             vllm_guided_decoding_regex = vllm_guided_decoding_regex,
+             vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
+             vllm_mode = vllm_mode,
+             vllm_server_base_url = vllm_server_base_url,
+             vllm_server_host = vllm_server_host,
+             vllm_server_port = vllm_server_port,
+             vllm_server_timeout = vllm_server_timeout,
+             vllm_tensor_parallel_size = vllm_tensor_parallel_size,
+             ds3_gather_for_generation = ds3_gather_for_generation,
+             model_init_kwargs = model_init_kwargs,
+             reward_weights = reward_weights,
+             dataset_num_proc = dataset_num_proc,
+             gpu_memory_utilization = gpu_memory_utilization, **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+         if unsloth_grpo_mini_batch is not None:
+             if self.generation_batch_size >= unsloth_grpo_mini_batch:
+                 self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
+             else:
+                 raise ValueError(
+                     f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
+                     f"which is per_device_train_batch_size * gradient_accumulation_steps."
+                 )
+         self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
+         self.max_seq_length = max_seq_length
+
+ pass
+
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
+     """"""
+
+     _tag_names = ["trl", "xpo"]
+     _name = "XPO"
+     _paper = {
+         "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
+         "id": "2405.21046",
+         # docstyle-ignore
+         "citation": textwrap.dedent("""\
+             @article{jung2024binary,
+                 title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
+                 author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
+                 year = 2024,
+                 eprint = {arXiv:2405.21046}
+             }"""),
+     }
+
+     def __init__(
+         self,
+         model: Union[PreTrainedModel, nn.Module] = None,
+         ref_model: Union[PreTrainedModel, nn.Module] = None,
+         reward_funcs: Optional[nn.Module] = None,
+         judge: Optional[BasePairwiseJudge] = None,
+         args: Optional[XPOConfig] = None,
+         data_collator: Optional[Callable] = None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+         processing_class: Optional[
+             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+         ] = None,
+         reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+         peft_config: Optional[dict] = None,
+         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+         # Deprecated parameters
+         reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+     ) -> None:
+         super().__init__(
+             model=model,
+             ref_model=ref_model,
+             judge=judge,
+             reward_funcs=reward_funcs,
+             reward_model=reward_model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             reward_processing_classes=reward_processing_classes,
+             peft_config=peft_config,
+             compute_metrics=compute_metrics,
+             callbacks=callbacks,
+             optimizers=optimizers,
+             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+         )
+
+         self._alpha = self.args.alpha
+
+         # Overwrite the stats dictionary to include XPO specific statistics
+         self.stats = {
+             # Remove "non_score_reward", "rlhf_reward", "scores"
+             # Add "loss/dpo", "loss/xpo"
+             "loss/dpo": [],
+             "loss/xpo": [],
+             "objective/kl": [],
+             "objective/entropy": [],
+             "rewards/chosen": [],
+             "rewards/rejected": [],
+             "rewards/accuracies": [],
+             "rewards/margins": [],
+             "logps/chosen": [],
+             "logps/rejected": [],
+             # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
+             "val/model_contain_eos_token": [],
+             "val/ref_contain_eos_token": [],
+             "alpha": [],
+             "beta": [],
+         }
+         if self.reward_funcs is not None:
+             if len(self.reward_funcs) != 1:
+                 raise ValueError("XPOTrainer only supports one reward function/model.")
+             self.reward_funcs = self.reward_funcs[0]
+             self.stats["objective/model_scores"] = []
+             self.stats["objective/ref_scores"] = []
+             self.stats["objective/scores_margin"] = []
+
+     @property
+     def alpha(self):
+         if isinstance(self._alpha, list):
+             epoch = self.state.epoch
+             return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
+         else:
+             return self._alpha
+
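Editor's note: the `alpha` property implements a per-epoch schedule: with a list, epoch i uses alpha[i] and later epochs reuse the last entry. One caveat worth flagging: transformers reports `self.state.epoch` as a float, so a strict reading of the list branch would need `int(epoch)` before indexing; the standalone sketch below uses an integer epoch:

    def alpha_for_epoch(alpha, epoch):
        # Same selection rule as the property above, factored out for illustration.
        if isinstance(alpha, list):
            return alpha[epoch] if epoch < len(alpha) else alpha[-1]
        return alpha

    print(alpha_for_epoch([1e-5, 5e-6, 1e-6], 1))  # 5e-06
    print(alpha_for_epoch([1e-5, 5e-6, 1e-6], 7))  # 1e-06 (clamped to the last entry)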
+     def _generate_completions(self, prompts, model):
+         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
+             model_output = unwrapped_policy_model_for_gen.generate(
+                 input_ids=prompts["input_ids"],
+                 attention_mask=prompts["attention_mask"],
+                 generation_config=self.generation_config,
+             )
+
+         actual_model_for_ref_generation: torch.nn.Module
+         if self.ref_model is None:
+             unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
+
+             if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
+                 actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
+             else:
+                 actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
+         else:
+             actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
+
+         with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
+             ref_output = final_ref_model_for_gen.generate(
+                 input_ids=prompts["input_ids"],
+                 attention_mask=prompts["attention_mask"],
+                 generation_config=self.generation_config,
+             )
+
+         return model_output, ref_output
+
+     def _process_completions(self, model_output, ref_output, prompts):
+         context_length = prompts["input_ids"].shape[1]
+
+         # Process model completions
+         model_completion_ids = model_output[:, context_length:]
+         model_completion_ids, model_completion_mask = truncate_right(
+             model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
+         )
+         model_data = {
+             "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
+             "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
+             "raw": prompts["raw"],
+         }
+
+         # Process reference model completions
+         ref_completion_ids = ref_output[:, context_length:]
+         ref_completion_ids, ref_completion_mask = truncate_right(
+             ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
+         )
+         ref_data = {
+             "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
+             "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
+             "raw": prompts["raw"],
+         }
+
+         return model_data, ref_data
+
+     def _compute_rewards(self, model_data, ref_data, context_length):
+         with torch.no_grad():
+             _, model_scores, _ = get_reward(
+                 self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
+             )
+             _, ref_scores, _ = get_reward(
+                 self.reward_funcs, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
+             )
+
+         # Apply EOS penalty if needed
+         if self.args.missing_eos_penalty is not None:
+             model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
+             ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
+             model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
+             ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
+
+         return model_scores, ref_scores
+
+     def _compute_judge(self, model_data, ref_data, context_length):
+         prompts = model_data["raw"]
+         model_data_completions = self.processing_class.batch_decode(
+             model_data["input_ids"][:, context_length:], skip_special_tokens=True
+         )
+         model_data_completions = [completion.strip() for completion in model_data_completions]
+
+         ref_data_completions = self.processing_class.batch_decode(
+             ref_data["input_ids"][:, context_length:], skip_special_tokens=True
+         )
+         ref_data_completions = [completion.strip() for completion in ref_data_completions]
+
+         if is_conversational({"prompt": prompts[0]}):
+             model_data_completions = [
+                 [{"role": "assistant", "content": completion}] for completion in model_data_completions
+             ]
+             environment = jinja2.Environment()
+             template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
+             prompts = [template.render(messages=message) for message in prompts]
+             model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
+
+             ref_data_completions = [
+                 [{"role": "assistant", "content": completion}] for completion in ref_data_completions
+             ]
+             ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
+
+         ranks_of_first_completion = self.judge.judge(
+             prompts,
+             list(zip(model_data_completions, ref_data_completions)),
+         )
+         # convert ranks to a True/False mask:
+         # when rank == 0, it means the first completion is the best
+         # when rank == 1, it means the second completion is the best
+         return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
+
+     def _compute_logprobs(self, model, model_data, ref_data, context_length):
+         def compute_logprobs_for_data(m, data):
+             output = m(data["input_ids"], attention_mask=data["attention_mask"])
+             logits = output.logits[:, context_length - 1 : -1]
+             token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
+             return token_logprobs
+
+         # Compute logprobs for model completions
+         model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+         # Compute logprobs for model on reference completions (for XPO loss)
+         model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
+
+         # Compute logprobs for reference model completions
+         with torch.no_grad():
+             if self.ref_model is None:
+                 with model.disable_adapter():
+                     ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+                     ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
+             else:
+                 ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
+                 ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
+
+         # Mask padding tokens
+         model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
+         ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
+         model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+         model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
+         ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
+         ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+
+         return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
+
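Editor's note: the slice `logits[:, context_length - 1 : -1]` in `_compute_logprobs` works because the logits at position t score token t+1, so the completion tokens `input_ids[:, context_length:]` are predicted by the logits one step earlier. A toy shape check:

    import torch

    B, P, C, V = 2, 4, 3, 11           # batch, prompt len, completion len, vocab
    logits = torch.randn(B, P + C, V)  # model output for the full sequence
    completion = torch.randint(0, V, (B, C))
    sliced = logits[:, P - 1 : -1]     # shape (B, C, V), aligned with `completion`
    assert sliced.shape == (B, C, V)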
+     def _compute_losses(
+         self,
+         model_logprobs_model_data,
+         model_logprobs_ref_data,
+         ref_logprobs_ref_data,
+         ref_logprobs_model_data,
+         chosen_mask,
+     ):
+         # Compute log probs
+         model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
+         model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
+         ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
+         ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
+
+         chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
+
+         rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
+
+         # Compute logits as the difference between chosen and rejected log ratios
+         logits = chosen_log_ratios - rejected_log_ratios
+
+         if self.args.loss_type == "sigmoid":
+             dpo_losses = -F.logsigmoid(self.beta * logits)
+         elif self.args.loss_type == "ipo":
+             dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
+         else:
+             raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
+
+         # Compute XPO specific loss
+         xpo_losses = self.alpha * model_logprobs_ref_data_sum
+
+         # Total loss
+         loss = (dpo_losses + xpo_losses).mean()
+
+         return loss, dpo_losses, xpo_losses
+
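Editor's note: a hand-computed sketch of the loss combination in `_compute_losses`, using made-up per-sequence log-prob sums. The XPO term adds `alpha * logp_policy(ref completion)` to the loss, so minimizing it pushes the policy's probability on reference-model samples down, which is how the exploration pressure enters:

    import torch
    import torch.nn.functional as F

    beta, alpha = 0.1, 1e-5
    chosen_log_ratios   = torch.tensor([0.7])    # policy-vs-ref log ratio, preferred completion
    rejected_log_ratios = torch.tensor([-0.2])
    model_logprobs_ref_data_sum = torch.tensor([-35.0])  # policy logp of the ref completion

    dpo_losses = -F.logsigmoid(beta * (chosen_log_ratios - rejected_log_ratios))  # ~0.6491
    xpo_losses = alpha * model_logprobs_ref_data_sum                              # -0.00035
    loss = (dpo_losses + xpo_losses).mean()
    print(loss)  # ~0.6488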
957
+ def _log_statistics(
958
+ self,
959
+ model_data,
960
+ ref_data,
961
+ model_logprobs_model_data,
962
+ model_logprobs_ref_data,
963
+ ref_logprobs_ref_data,
964
+ ref_logprobs_model_data,
965
+ chosen_mask,
966
+ dpo_losses,
967
+ xpo_losses,
968
+ context_length,
969
+ model_scores=None,
970
+ ref_scores=None,
971
+ ):
972
+ # Helper function to gather and compute mean
973
+ def gather_mean(tensor):
974
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
975
+
976
+ # Log losses
977
+ self.stats["loss/dpo"].append(gather_mean(dpo_losses))
978
+ self.stats["loss/xpo"].append(gather_mean(xpo_losses))
979
+
980
+ # Log scores
981
+ if self.reward_funcs is not None:
982
+ self.stats["objective/model_scores"].append(gather_mean(model_scores))
983
+ self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
984
+ self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
985
+
986
+ # Log logprobs
987
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
988
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
989
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
990
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
991
+
992
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
993
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
994
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
995
+
996
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
997
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
998
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
999
+
1000
+ self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
1001
+ self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
1002
+
1003
+ # Log rewards
1004
+ # Compute various statistics
1005
+ chosen_rewards = chosen_log_ratios * self.beta
1006
+ rejected_rewards = rejected_log_ratios * self.beta
1007
+ self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
1008
+ self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
1009
+
1010
+ # Calculate KL divergence for model and ref data
1011
+ kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
1012
+ kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
1013
+ mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
1014
+ self.stats["objective/kl"].append(gather_mean(mean_kl))
1015
+
1016
+ # Calculate entropy for model and ref data
1017
+ entropy_model_data = -model_logprobs_model_data.sum(1)
1018
+ entropy_ref_data = -model_logprobs_ref_data.sum(1)
1019
+ mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
1020
+ self.stats["objective/entropy"].append(gather_mean(mean_entropy))
1021
+
1022
+ # Calculate margins
1023
+ margin = chosen_rewards - rejected_rewards
1024
+ self.stats["rewards/margins"].append(gather_mean(margin.mean()))
1025
+
1026
+ # Calculate accuracy
1027
+ accuracy = (margin > 0).float()
1028
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
1029
+
1030
+ # Log EOS token statistics
1031
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
1032
+ ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
1033
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
1034
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
1035
+
1036
+ # Log alpha and beta
1037
+ self.stats["alpha"].append(self.alpha)
1038
+ self.stats["beta"].append(self.beta)
1039
+
1040
+ def training_step(
1041
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
1042
+ ) -> torch.Tensor:
1043
+ model.train()
1044
+
1045
+ # Apply chat template and tokenize the input
1046
+ batch_size = len(next(iter(inputs.values())))
1047
+ prompts = inputs["prompt"]
1048
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
1049
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
1050
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
1051
+ inputs = self.data_collator(inputs)
1052
+
1053
+ # need the prompt_ only
1054
+ inputs = self._prepare_inputs(inputs)
1055
+ context_length = inputs["prompt_input_ids"].shape[1]
1056
+ prompts = {
1057
+ "input_ids": inputs["prompt_input_ids"],
1058
+ "attention_mask": inputs["prompt_attention_mask"],
1059
+ "raw": prompts,
1060
+ }
1061
+ del inputs
1062
+
1063
+ # Sample completions from both the model and the reference model
1064
+ model_output, ref_output = self._generate_completions(prompts, model)
1065
+
1066
+ # Process model completions
1067
+ model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
1068
+
1069
+ # Compute rewards
1070
+ if self.reward_funcs is not None:
1071
+ model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
1072
+ chosen_mask = model_scores >= ref_scores
1073
+ else:
1074
+ model_scores, ref_scores = None, None
1075
+ chosen_mask = self._compute_judge(model_data, ref_data, context_length)
1076
+
1077
+ # Compute logprobs
1078
+ model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
1079
+ self._compute_logprobs(model, model_data, ref_data, context_length)
1080
+ )
1081
+
1082
+ # Compute loss
1083
+ loss, dpo_losses, xpo_losses = self._compute_losses(
1084
+ model_logprobs_model_data,
1085
+ model_logprobs_ref_data,
1086
+ ref_logprobs_ref_data,
1087
+ ref_logprobs_model_data,
1088
+ chosen_mask,
1089
+ )
1090
+
1091
+ # Log everything
1092
+ self._log_statistics(
1093
+ model_data,
1094
+ ref_data,
1095
+ model_logprobs_model_data.detach(),
1096
+ model_logprobs_ref_data.detach(),
1097
+ ref_logprobs_ref_data,
1098
+ ref_logprobs_model_data,
1099
+ chosen_mask,
1100
+ dpo_losses.detach(),
1101
+ xpo_losses.detach(),
1102
+ context_length,
1103
+ model_scores,
1104
+ ref_scores,
1105
+ )
1106
+
1107
+ if (
1108
+ self.args.torch_empty_cache_steps is not None
1109
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
1110
+ ):
1111
+ empty_cache()
1112
+
1113
+ kwargs = {}
1114
+ # For LOMO optimizers you need to explicitly use the learning rate
1115
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
1116
+ kwargs["learning_rate"] = self._get_learning_rate()
1117
+
1118
+ if self.args.n_gpu > 1:
1119
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
1120
+
1121
+ self.accelerator.backward(loss, **kwargs)
1122
+
1123
+ return loss.detach() / self.args.gradient_accumulation_steps
1124
+ class UnslothXPOTrainer(_UnslothXPOTrainer):
1125
+ """
1126
+
1127
+ Trainer for Exploratory Preference Optimization (XPO).
1128
+
1129
+ It is implemented as a subclass of [`OnlineDPOTrainer`].
1130
+
1131
+ Args:
1132
+ model ([`~transformers.PreTrainedModel`]):
1133
+ The model to train, preferably an `AutoModelForCausalLM`.
1134
+ ref_model ([`PreTrainedModelWrapper`]):
1135
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
1136
+ and loss. If no reference model is provided, the trainer will create a reference model with the same
1137
+ architecture as the model to be optimized.
1138
+ reward_funcs ([`~transformers.PreTrainedModel`]):
1139
+ The reward model to score completions with, preferably an
1140
+ [`~transformers.AutoModelForSequenceClassification`].
1141
+ judge ([`BasePairwiseJudge`]):
1142
+ The judge to use for pairwise comparison of model completions.
1143
+ args ([`XPOConfig`]):
1144
+ The XPO config arguments to use for training.
1145
+ data_collator ([`~transformers.DataCollator`]):
1146
+ The data collator to use for training. If None is specified, the default data collator
1147
+ ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
1148
+ sequences in the batch, given a dataset of paired sequences.
1149
+ train_dataset ([`~datasets.Dataset`]):
1150
+ The dataset to use for training.
1151
+ eval_dataset ([`~datasets.Dataset`]):
1152
+ The dataset to use for evaluation.
1153
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
1154
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1155
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1156
+ reuse the fine-tuned model.
1157
+ peft_config (`dict`):
1158
+ The peft config to use for training.
1159
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1160
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
1161
+ metric values.
1162
+ callbacks (`list[transformers.TrainerCallback]`):
1163
+ The callbacks to use for training.
1164
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1165
+ The optimizer and scheduler to use for training.
1166
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1167
+ The function to use to preprocess the logits before computing the metrics.
1168
+
1169
+ reward_model:
1170
+
1171
+ <Deprecated version="0.22.0">
1172
+
1173
+ This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
1174
+
1175
+ </Deprecated>
1176
+
1177
+ """
1178
+     def __init__(
+         self,
+         model = None,
+         ref_model = None,
+         reward_funcs = None,
+         judge = None,
+         args = None,
+         data_collator = None,
+         train_dataset = None,
+         eval_dataset = None,
+         processing_class = None,
+         reward_processing_classes = None,
+         peft_config = None,
+         compute_metrics = None,
+         callbacks = None,
+         preprocess_logits_for_metrics = None,
+         reward_model = None,
+         **kwargs
+     ):
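+         # The precision logic below matches the training flags to the model's actual
+         # dtype (float16 vs bfloat16), honoring the UNSLOTH_FORCE_FLOAT32 and
+         # UNSLOTH_MIXED_PRECISION environment variables, and mirrors the choice into
+         # ACCELERATE_MIXED_PRECISION.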
+         if args is None: args = UnslothXPOConfig()
+         use_bf16 = getattr(args, 'bf16', False)
+         if type(use_bf16) is not bool: use_bf16 = False
+         use_fp16 = getattr(args, 'fp16', False)
+         if type(use_fp16) is not bool: use_fp16 = False
+         force_float32 = False
+         full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+         if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+             print('Unsloth: Switching to float32 training since model cannot work with float16')
+             force_float32 = True
+         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+         dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+         if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+         from unsloth_zoo.utils import _get_dtype
+         dtype = _get_dtype(dtype)
+         float16 = dtype == torch.float16
+         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+         if force_float32:
+             # Forced float32 training
+             args.fp16 = False
+             args.bf16 = False
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+             # args.mixed_precision is a new argument which needs to be set now
+         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+             # Mixed precision training
+             args.fp16 = float16
+             args.bf16 = not float16
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+             # args.mixed_precision is a new argument which needs to be set now
+         elif mixed_precision_dtype == 'bfloat16':
+             # Both False since bfloat16 full finetuning doesn't do any autocasting.
+             args.fp16 = False
+             args.bf16 = False
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+             if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+             # args.mixed_precision is a new argument which needs to be set now
+ 
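+         # Evaluation defaults: turn on step-based evaluation when an eval dataset is
+         # supplied, warn on transformers <= 4.45.2 (which needs Unsloth's
+         # gradient-accumulation fix), and align eval batch size and accumulation
+         # with the training settings.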
+         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+             args.eval_strategy = 'steps'
+             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+         if ga_steps is not None and ga_steps > 1:
+             from transformers import __version__ as transformers_version
+             if Version(transformers_version) <= Version('4.45.2'):
+                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+         if getattr(args, 'eval_strategy', 'no') != 'no':
+             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+         if type(fp16_full_eval) is not bool: fp16_full_eval = False
+         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+         if type(bf16_full_eval) is not bool: bf16_full_eval = False
+         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+         if force_float32:
+             args.bf16_full_eval = False
+             args.fp16_full_eval = False
+         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+             args.bf16_full_eval = True
+             args.fp16_full_eval = False
+         elif not bf16_full_eval and not fp16_full_eval:
+             args.bf16_full_eval = args.bf16
+             args.fp16_full_eval = args.fp16
+         _output_logits = False
+         if locals().get('compute_metrics', None) is not None: _output_logits = True
+         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+         if _output_logits:
+             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
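+         # transformers' Trainer expects `model.warnings_issued` to behave like a
+         # dict; coerce or reset it if a patched model carries something else.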
+         if model is not None:
+             _warnings_issued = getattr(model, 'warnings_issued', None)
+             if _warnings_issued is None:
+                 model.warnings_issued = {}
+             elif not isinstance(_warnings_issued, dict):
+                 try:
+                     model.warnings_issued = dict(_warnings_issued)
+                 except Exception:
+                     model.warnings_issued = {}
+         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+             pass
+         else:
+             model_max_seq_length = getattr(model, 'max_seq_length', None)
+             args_max_seq_length = getattr(args, 'max_seq_length', None)
+             if args_max_seq_length is None and model_max_seq_length is not None:
+                 max_seq_length = model.max_seq_length
+                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+             elif args_max_seq_length is not None and model_max_seq_length is not None:
+                 if args_max_seq_length > model_max_seq_length:
+                     print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
+                           'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
+                     args.max_seq_length = model_max_seq_length
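+         # Put the model into training mode (with gradient checkpointing) and force
+         # right-padding on the tokenizer / processing class.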
+         if model is not None and hasattr(model, 'for_training'):
+             model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+         if 'processing_class' in locals():
+             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
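+         # Collator fix-ups: use a language-modeling collator when the dataset has no
+         # 'labels' column, a seq2seq collator when it does, and fall back to the
+         # inner `.tokenizer` whenever the processing class itself cannot `pad`.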
+         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                 data_collator = TransformersDataCollatorForLanguageModeling(
+                     __tokenizer,
+                     mlm = False,
+                     mlm_probability = 0.0,
+                     pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                 )
+             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                 data_collator = DataCollatorForSeq2Seq(
+                     __tokenizer,
+                     pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                 )
+         else:
+             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                 if isinstance(data_collator, DataCollatorForSeq2Seq):
+                     data_collator = DataCollatorForSeq2Seq(
+                         __tokenizer.tokenizer,
+                         pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                     )
+                 else:
+                     data_collator = TransformersDataCollatorForLanguageModeling(
+                         __tokenizer.tokenizer,
+                         mlm = False,
+                         mlm_probability = 0.0,
+                         pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+                     )
+         other_metrics = []
+ 
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('xpo_trainer', other_metrics)
+ 
+         # [TODO] Fix up DataParallel multiplying batch sizes
+         # [TODO] DDP works, but DP seems to not work? [TODO]
+         if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+             if getattr(args, "_n_gpu", 1) != 1:
+                 args._n_gpu = 1
+         if "model" in locals() and hasattr(model, "for_training"):
+             model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+         super().__init__(
+             model = model,
+             ref_model = ref_model,
+             reward_funcs = reward_funcs,
+             judge = judge,
+             args = args,
+             data_collator = data_collator,
+             train_dataset = train_dataset,
+             eval_dataset = eval_dataset,
+             processing_class = processing_class,
+             reward_processing_classes = reward_processing_classes,
+             peft_config = peft_config,
+             compute_metrics = compute_metrics,
+             callbacks = callbacks,
+             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+             reward_model = reward_model,
+             **kwargs,
+         )
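+         # Post-init fix-ups: return the model to inference mode, remove any stale
+         # NEFTune hook, propagate the accelerator's grad scaler down the wrapped
+         # model chain, wrap `train` via prepare_for_training_mode, and copy the
+         # chat template onto the vLLM tokenizer if it lacks one.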
+ if "model" in locals() and hasattr(model, "for_inference"):
1360
+ model.for_inference()
1361
+ if hasattr(self, 'neftune_hook_handle'):
1362
+ self.neftune_hook_handle.remove()
1363
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1364
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1365
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1366
+ pass
1367
+ if hasattr(self, 'accelerator'):
1368
+ scaler = self.accelerator.scaler
1369
+ current_model = model
1370
+ while hasattr(current_model, 'model'):
1371
+ current_model.accelerator_scaler = scaler
1372
+ current_model = current_model.model
1373
+ current_model.accelerator_scaler = scaler
1374
+ pass
1375
+ if hasattr(self, 'train'):
1376
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1377
+ pass
1378
+ if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
1379
+ _vllm_tok = self.llm.get_tokenizer()
1380
+ _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
1381
+ if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
1382
+ _vllm_tok.chat_template = _pc.chat_template
1383
+ pass
1384
+
1385
+ pass