Upload folder using huggingface_hub
Browse files- _dashboard_state.json +94 -123
- colab_train_llama32_remote.py +340 -0
- colab_train_unsloth.ipynb +347 -128
- pyproject.toml +0 -2
- tests/test_colab_train_llama32_remote.py +17 -0
- unsloth_2048.ipynb +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothCPOTrainer.py +1 -1
- unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothGRPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothKTOTrainer.py +1 -1
- unsloth_compiled_cache/UnslothNashMDTrainer.py +1340 -0
- unsloth_compiled_cache/UnslothORPOTrainer.py +1 -1
- unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothRLOOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothXPOTrainer.py +1385 -0
_dashboard_state.json
CHANGED
|
@@ -1,29 +1,29 @@
|
|
| 1 |
{
|
| 2 |
-
"timestamp":
|
| 3 |
"step": 20,
|
| 4 |
"episode_done": false,
|
| 5 |
-
"cumulative_reward":
|
| 6 |
-
"gen_time_s":
|
| 7 |
-
"model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"
|
| 8 |
"model_thinking": "",
|
| 9 |
"thinking_enabled": false,
|
| 10 |
"task": {
|
| 11 |
-
"problem_statement": "
|
| 12 |
"organism": "human",
|
| 13 |
-
"tissue": "
|
| 14 |
"modality": "scRNA-seq",
|
| 15 |
"conditions": [
|
| 16 |
"healthy",
|
| 17 |
-
"
|
| 18 |
],
|
| 19 |
-
"budget_limit":
|
| 20 |
-
"time_limit_days":
|
| 21 |
},
|
| 22 |
"resources": {
|
| 23 |
"budget_used": 28900.0,
|
| 24 |
-
"budget_remaining":
|
| 25 |
"time_used_days": 18.0,
|
| 26 |
-
"time_remaining_days":
|
| 27 |
"samples_consumed": 0,
|
| 28 |
"compute_hours_used": 12.0
|
| 29 |
},
|
|
@@ -32,9 +32,9 @@
|
|
| 32 |
"step_index": 1,
|
| 33 |
"action_type": "collect_sample",
|
| 34 |
"method": null,
|
| 35 |
-
"output_summary": "Collected 6 samples (quality=0.
|
| 36 |
"success": true,
|
| 37 |
-
"quality_score": 0.
|
| 38 |
"resource_cost": 5000.0,
|
| 39 |
"time_cost_days": 7.0
|
| 40 |
},
|
|
@@ -62,9 +62,9 @@
|
|
| 62 |
"step_index": 4,
|
| 63 |
"action_type": "prepare_library",
|
| 64 |
"method": null,
|
| 65 |
-
"output_summary": "Library prepared (complexity=0.
|
| 66 |
"success": true,
|
| 67 |
-
"quality_score": 0.
|
| 68 |
"resource_cost": 8000.0,
|
| 69 |
"time_cost_days": 3.0
|
| 70 |
},
|
|
@@ -72,42 +72,42 @@
|
|
| 72 |
"step_index": 5,
|
| 73 |
"action_type": "sequence_cells",
|
| 74 |
"method": null,
|
| 75 |
-
"output_summary": "Sequenced
|
| 76 |
"success": true,
|
| 77 |
-
"quality_score": 0.
|
| 78 |
"resource_cost": 15000.0,
|
| 79 |
"time_cost_days": 5.0
|
| 80 |
},
|
| 81 |
{
|
| 82 |
"step_index": 6,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
"action_type": "run_qc",
|
| 84 |
"method": null,
|
| 85 |
"output_summary": "QC metrics computed",
|
| 86 |
"success": true,
|
| 87 |
-
"quality_score": 0.
|
| 88 |
"resource_cost": 100.0,
|
| 89 |
"time_cost_days": 0.5
|
| 90 |
},
|
| 91 |
{
|
| 92 |
-
"step_index":
|
| 93 |
"action_type": "filter_data",
|
| 94 |
"method": null,
|
| 95 |
-
"output_summary": "Filtered
|
| 96 |
"success": true,
|
| 97 |
-
"quality_score": 0.
|
| 98 |
"resource_cost": 50.0,
|
| 99 |
"time_cost_days": 0.2
|
| 100 |
},
|
| 101 |
-
{
|
| 102 |
-
"step_index": 8,
|
| 103 |
-
"action_type": "cluster_cells",
|
| 104 |
-
"method": null,
|
| 105 |
-
"output_summary": "Action blocked: Cannot cluster before normalisation",
|
| 106 |
-
"success": false,
|
| 107 |
-
"quality_score": 1.0,
|
| 108 |
-
"resource_cost": 100.0,
|
| 109 |
-
"time_cost_days": 0.5
|
| 110 |
-
},
|
| 111 |
{
|
| 112 |
"step_index": 9,
|
| 113 |
"action_type": "normalize_data",
|
|
@@ -120,11 +120,11 @@
|
|
| 120 |
},
|
| 121 |
{
|
| 122 |
"step_index": 10,
|
| 123 |
-
"action_type": "
|
| 124 |
"method": null,
|
| 125 |
-
"output_summary": "
|
| 126 |
-
"success":
|
| 127 |
-
"quality_score":
|
| 128 |
"resource_cost": 100.0,
|
| 129 |
"time_cost_days": 0.5
|
| 130 |
},
|
|
@@ -132,20 +132,20 @@
|
|
| 132 |
"step_index": 11,
|
| 133 |
"action_type": "design_followup_experiment",
|
| 134 |
"method": null,
|
| 135 |
-
"output_summary": "
|
| 136 |
-
"success":
|
| 137 |
-
"quality_score":
|
| 138 |
"resource_cost": 100.0,
|
| 139 |
"time_cost_days": 0.5
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"step_index": 12,
|
| 143 |
-
"action_type": "
|
| 144 |
"method": null,
|
| 145 |
-
"output_summary": "
|
| 146 |
-
"success":
|
| 147 |
-
"quality_score":
|
| 148 |
-
"resource_cost":
|
| 149 |
"time_cost_days": 0.5
|
| 150 |
},
|
| 151 |
{
|
|
@@ -160,26 +160,6 @@
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"step_index": 14,
|
| 163 |
-
"action_type": "design_followup_experiment",
|
| 164 |
-
"method": null,
|
| 165 |
-
"output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
|
| 166 |
-
"success": true,
|
| 167 |
-
"quality_score": 0.28,
|
| 168 |
-
"resource_cost": 100.0,
|
| 169 |
-
"time_cost_days": 0.5
|
| 170 |
-
},
|
| 171 |
-
{
|
| 172 |
-
"step_index": 15,
|
| 173 |
-
"action_type": "synthesize_conclusion",
|
| 174 |
-
"method": null,
|
| 175 |
-
"output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
|
| 176 |
-
"success": false,
|
| 177 |
-
"quality_score": 1.0,
|
| 178 |
-
"resource_cost": 0.0,
|
| 179 |
-
"time_cost_days": 0.5
|
| 180 |
-
},
|
| 181 |
-
{
|
| 182 |
-
"step_index": 16,
|
| 183 |
"action_type": "synthesize_conclusion",
|
| 184 |
"method": null,
|
| 185 |
"output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
|
|
@@ -222,8 +202,8 @@
|
|
| 222 |
"Cannot synthesise conclusion before inferring pathways or mechanisms"
|
| 223 |
],
|
| 224 |
"uncertainty_summary": {
|
| 225 |
-
"avg_uncertainty": 0.
|
| 226 |
-
"avg_quality": 0.
|
| 227 |
},
|
| 228 |
"reward_breakdown": {
|
| 229 |
"validity": -1.0,
|
|
@@ -249,92 +229,83 @@
|
|
| 249 |
"latent": {
|
| 250 |
"cell_populations": [
|
| 251 |
{
|
| 252 |
-
"name": "
|
| 253 |
-
"proportion": 0.
|
| 254 |
"marker_genes": [
|
| 255 |
-
"
|
| 256 |
-
"
|
| 257 |
-
"
|
| 258 |
],
|
| 259 |
-
"state": "
|
| 260 |
},
|
| 261 |
{
|
| 262 |
-
"name": "
|
| 263 |
-
"proportion": 0.
|
| 264 |
"marker_genes": [
|
| 265 |
-
"
|
| 266 |
-
"
|
| 267 |
-
"
|
| 268 |
],
|
| 269 |
"state": "normal"
|
| 270 |
},
|
| 271 |
{
|
| 272 |
-
"name": "
|
| 273 |
-
"proportion": 0.
|
| 274 |
"marker_genes": [
|
| 275 |
-
"
|
| 276 |
-
"
|
| 277 |
-
"
|
| 278 |
],
|
| 279 |
"state": "quiescent"
|
| 280 |
},
|
| 281 |
{
|
| 282 |
-
"name": "
|
| 283 |
-
"proportion": 0.
|
| 284 |
-
"marker_genes": [
|
| 285 |
-
"CD3D",
|
| 286 |
-
"CD3E",
|
| 287 |
-
"CD8A"
|
| 288 |
-
],
|
| 289 |
-
"state": "activated"
|
| 290 |
-
},
|
| 291 |
-
{
|
| 292 |
-
"name": "macrophage",
|
| 293 |
-
"proportion": 0.096,
|
| 294 |
"marker_genes": [
|
| 295 |
-
"
|
| 296 |
-
"
|
| 297 |
-
"
|
| 298 |
],
|
| 299 |
-
"state": "
|
| 300 |
},
|
| 301 |
{
|
| 302 |
-
"name": "
|
| 303 |
-
"proportion": 0.
|
| 304 |
"marker_genes": [
|
| 305 |
-
"
|
| 306 |
-
"
|
| 307 |
-
"
|
| 308 |
],
|
| 309 |
-
"state": "
|
| 310 |
}
|
| 311 |
],
|
| 312 |
"true_markers": [
|
| 313 |
-
"
|
| 314 |
-
"
|
| 315 |
-
"
|
| 316 |
-
"
|
| 317 |
],
|
| 318 |
"causal_mechanisms": [
|
| 319 |
-
"
|
| 320 |
-
"
|
|
|
|
| 321 |
],
|
| 322 |
"true_pathways": {
|
| 323 |
-
"
|
| 324 |
-
"
|
| 325 |
-
"
|
| 326 |
-
"
|
| 327 |
-
"
|
| 328 |
},
|
| 329 |
-
"true_de_genes_count":
|
| 330 |
-
"true_regulatory_network_size":
|
| 331 |
"confounders": {},
|
| 332 |
-
"n_true_cells":
|
| 333 |
"technical": {
|
| 334 |
-
"ambient_rna_fraction": 0.
|
| 335 |
-
"doublet_rate": 0.
|
| 336 |
-
"dropout_rate": 0.
|
| 337 |
-
"sample_quality": 0.
|
| 338 |
"library_complexity": 0.8,
|
| 339 |
"capture_efficiency": 0.6
|
| 340 |
},
|
|
@@ -359,8 +330,8 @@
|
|
| 359 |
"followup_designed": true,
|
| 360 |
"subagent_review_requested": false,
|
| 361 |
"conclusion_reached": false,
|
| 362 |
-
"n_cells_sequenced":
|
| 363 |
-
"n_cells_after_filter":
|
| 364 |
"n_clusters_found": "7",
|
| 365 |
"n_de_genes_found": null,
|
| 366 |
"n_markers_found": null
|
|
|
|
| 1 |
{
|
| 2 |
+
"timestamp": 1772994420.314485,
|
| 3 |
"step": 20,
|
| 4 |
"episode_done": false,
|
| 5 |
+
"cumulative_reward": 1.8024931955050638,
|
| 6 |
+
"gen_time_s": 23.52,
|
| 7 |
+
"model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"disease_comparison_groups\": true,\n\"treatment_condition_focus\": false,\n\"criterion_for_target_genes\": \"Differentially Expressed Between Healthy vs AD Brain Cells\"\n},\n\"Justification\": \"The primary goal remains identifying DGEs across conditions rather than targeting specific known pathways first. While 'filter_data' was previously performed but failed due to lack of validated markers from previous attempts, re-running filtering based solely on technical/quality criteria would still leave ambiguous candidates if noise pe",
|
| 8 |
"model_thinking": "",
|
| 9 |
"thinking_enabled": false,
|
| 10 |
"task": {
|
| 11 |
+
"problem_statement": "Identify differentially expressed genes between Alzheimer's disease and healthy brain tissue using single-cell RNA sequencing.",
|
| 12 |
"organism": "human",
|
| 13 |
+
"tissue": "brain",
|
| 14 |
"modality": "scRNA-seq",
|
| 15 |
"conditions": [
|
| 16 |
"healthy",
|
| 17 |
+
"Alzheimer"
|
| 18 |
],
|
| 19 |
+
"budget_limit": 84032.3891133686,
|
| 20 |
+
"time_limit_days": 116.86792941835967
|
| 21 |
},
|
| 22 |
"resources": {
|
| 23 |
"budget_used": 28900.0,
|
| 24 |
+
"budget_remaining": 55132.39,
|
| 25 |
"time_used_days": 18.0,
|
| 26 |
+
"time_remaining_days": 98.9,
|
| 27 |
"samples_consumed": 0,
|
| 28 |
"compute_hours_used": 12.0
|
| 29 |
},
|
|
|
|
| 32 |
"step_index": 1,
|
| 33 |
"action_type": "collect_sample",
|
| 34 |
"method": null,
|
| 35 |
+
"output_summary": "Collected 6 samples (quality=0.56)",
|
| 36 |
"success": true,
|
| 37 |
+
"quality_score": 0.562,
|
| 38 |
"resource_cost": 5000.0,
|
| 39 |
"time_cost_days": 7.0
|
| 40 |
},
|
|
|
|
| 62 |
"step_index": 4,
|
| 63 |
"action_type": "prepare_library",
|
| 64 |
"method": null,
|
| 65 |
+
"output_summary": "Library prepared (complexity=0.75)",
|
| 66 |
"success": true,
|
| 67 |
+
"quality_score": 0.751,
|
| 68 |
"resource_cost": 8000.0,
|
| 69 |
"time_cost_days": 3.0
|
| 70 |
},
|
|
|
|
| 72 |
"step_index": 5,
|
| 73 |
"action_type": "sequence_cells",
|
| 74 |
"method": null,
|
| 75 |
+
"output_summary": "Sequenced 4574 cells, 9256 genes detected, median UMI=3083",
|
| 76 |
"success": true,
|
| 77 |
+
"quality_score": 0.47,
|
| 78 |
"resource_cost": 15000.0,
|
| 79 |
"time_cost_days": 5.0
|
| 80 |
},
|
| 81 |
{
|
| 82 |
"step_index": 6,
|
| 83 |
+
"action_type": "filter_data",
|
| 84 |
+
"method": null,
|
| 85 |
+
"output_summary": "Action blocked: Cannot filter data before QC",
|
| 86 |
+
"success": false,
|
| 87 |
+
"quality_score": 1.0,
|
| 88 |
+
"resource_cost": 50.0,
|
| 89 |
+
"time_cost_days": 0.2
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"step_index": 7,
|
| 93 |
"action_type": "run_qc",
|
| 94 |
"method": null,
|
| 95 |
"output_summary": "QC metrics computed",
|
| 96 |
"success": true,
|
| 97 |
+
"quality_score": 0.829,
|
| 98 |
"resource_cost": 100.0,
|
| 99 |
"time_cost_days": 0.5
|
| 100 |
},
|
| 101 |
{
|
| 102 |
+
"step_index": 8,
|
| 103 |
"action_type": "filter_data",
|
| 104 |
"method": null,
|
| 105 |
+
"output_summary": "Filtered 4574 \u2192 3996 cells (87% retained)",
|
| 106 |
"success": true,
|
| 107 |
+
"quality_score": 0.874,
|
| 108 |
"resource_cost": 50.0,
|
| 109 |
"time_cost_days": 0.2
|
| 110 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
{
|
| 112 |
"step_index": 9,
|
| 113 |
"action_type": "normalize_data",
|
|
|
|
| 120 |
},
|
| 121 |
{
|
| 122 |
"step_index": 10,
|
| 123 |
+
"action_type": "cluster_cells",
|
| 124 |
"method": null,
|
| 125 |
+
"output_summary": "Found 7 clusters",
|
| 126 |
+
"success": true,
|
| 127 |
+
"quality_score": 0.79,
|
| 128 |
"resource_cost": 100.0,
|
| 129 |
"time_cost_days": 0.5
|
| 130 |
},
|
|
|
|
| 132 |
"step_index": 11,
|
| 133 |
"action_type": "design_followup_experiment",
|
| 134 |
"method": null,
|
| 135 |
+
"output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
|
| 136 |
+
"success": true,
|
| 137 |
+
"quality_score": 0.28,
|
| 138 |
"resource_cost": 100.0,
|
| 139 |
"time_cost_days": 0.5
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"step_index": 12,
|
| 143 |
+
"action_type": "synthesize_conclusion",
|
| 144 |
"method": null,
|
| 145 |
+
"output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
|
| 146 |
+
"success": false,
|
| 147 |
+
"quality_score": 1.0,
|
| 148 |
+
"resource_cost": 0.0,
|
| 149 |
"time_cost_days": 0.5
|
| 150 |
},
|
| 151 |
{
|
|
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"step_index": 14,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
"action_type": "synthesize_conclusion",
|
| 164 |
"method": null,
|
| 165 |
"output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
|
|
|
|
| 202 |
"Cannot synthesise conclusion before inferring pathways or mechanisms"
|
| 203 |
],
|
| 204 |
"uncertainty_summary": {
|
| 205 |
+
"avg_uncertainty": 0.224,
|
| 206 |
+
"avg_quality": 0.814
|
| 207 |
},
|
| 208 |
"reward_breakdown": {
|
| 209 |
"validity": -1.0,
|
|
|
|
| 229 |
"latent": {
|
| 230 |
"cell_populations": [
|
| 231 |
{
|
| 232 |
+
"name": "excitatory_neuron",
|
| 233 |
+
"proportion": 0.349,
|
| 234 |
"marker_genes": [
|
| 235 |
+
"SLC17A7",
|
| 236 |
+
"CAMK2A",
|
| 237 |
+
"NRGN"
|
| 238 |
],
|
| 239 |
+
"state": "stressed"
|
| 240 |
},
|
| 241 |
{
|
| 242 |
+
"name": "inhibitory_neuron",
|
| 243 |
+
"proportion": 0.209,
|
| 244 |
"marker_genes": [
|
| 245 |
+
"GAD1",
|
| 246 |
+
"GAD2",
|
| 247 |
+
"SLC32A1"
|
| 248 |
],
|
| 249 |
"state": "normal"
|
| 250 |
},
|
| 251 |
{
|
| 252 |
+
"name": "astrocyte",
|
| 253 |
+
"proportion": 0.211,
|
| 254 |
"marker_genes": [
|
| 255 |
+
"GFAP",
|
| 256 |
+
"AQP4",
|
| 257 |
+
"SLC1A3"
|
| 258 |
],
|
| 259 |
"state": "quiescent"
|
| 260 |
},
|
| 261 |
{
|
| 262 |
+
"name": "oligodendrocyte",
|
| 263 |
+
"proportion": 0.153,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
"marker_genes": [
|
| 265 |
+
"MBP",
|
| 266 |
+
"PLP1",
|
| 267 |
+
"MOG"
|
| 268 |
],
|
| 269 |
+
"state": "myelinating"
|
| 270 |
},
|
| 271 |
{
|
| 272 |
+
"name": "OPC",
|
| 273 |
+
"proportion": 0.078,
|
| 274 |
"marker_genes": [
|
| 275 |
+
"PDGFRA",
|
| 276 |
+
"CSPG4",
|
| 277 |
+
"OLIG2"
|
| 278 |
],
|
| 279 |
+
"state": "progenitor"
|
| 280 |
}
|
| 281 |
],
|
| 282 |
"true_markers": [
|
| 283 |
+
"TREM2",
|
| 284 |
+
"APOE",
|
| 285 |
+
"GFAP",
|
| 286 |
+
"C1QA"
|
| 287 |
],
|
| 288 |
"causal_mechanisms": [
|
| 289 |
+
"TREM2-mediated microglial activation in amyloid clearance",
|
| 290 |
+
"complement-driven synaptic pruning",
|
| 291 |
+
"reactive astrogliosis amplifying neuroinflammation"
|
| 292 |
],
|
| 293 |
"true_pathways": {
|
| 294 |
+
"complement_cascade": 0.839,
|
| 295 |
+
"neuroinflammation": 0.805,
|
| 296 |
+
"amyloid_processing": 0.666,
|
| 297 |
+
"synaptic_signalling": 0.394,
|
| 298 |
+
"lipid_metabolism": 0.674
|
| 299 |
},
|
| 300 |
+
"true_de_genes_count": 10,
|
| 301 |
+
"true_regulatory_network_size": 0,
|
| 302 |
"confounders": {},
|
| 303 |
+
"n_true_cells": 7619,
|
| 304 |
"technical": {
|
| 305 |
+
"ambient_rna_fraction": 0.04108598341080635,
|
| 306 |
+
"doublet_rate": 0.045763110874719674,
|
| 307 |
+
"dropout_rate": 0.07138299827651534,
|
| 308 |
+
"sample_quality": 0.9242864922806572,
|
| 309 |
"library_complexity": 0.8,
|
| 310 |
"capture_efficiency": 0.6
|
| 311 |
},
|
|
|
|
| 330 |
"followup_designed": true,
|
| 331 |
"subagent_review_requested": false,
|
| 332 |
"conclusion_reached": false,
|
| 333 |
+
"n_cells_sequenced": 4574,
|
| 334 |
+
"n_cells_after_filter": 3996,
|
| 335 |
"n_clusters_found": "7",
|
| 336 |
"n_de_genes_found": null,
|
| 337 |
"n_markers_found": null
|
colab_train_llama32_remote.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal Colab entrypoint for Unsloth GRPO against a remote OpenEnv Space.
|
| 2 |
+
|
| 3 |
+
This keeps the repo's prompt formatting and action parsing logic, but builds
|
| 4 |
+
prompt states by interacting with a deployed OpenEnv Hugging Face Space instead
|
| 5 |
+
of the local in-process environment. That makes the Colab workflow match the
|
| 6 |
+
remote environment users actually want to train against.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 15 |
+
|
| 16 |
+
from client import BioExperimentEnv
|
| 17 |
+
import training_script as base
|
| 18 |
+
|
| 19 |
+
DEFAULT_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
|
| 20 |
+
DEFAULT_OUTPUT_DIR = "artifacts/grpo-unsloth-llama32-3b-space"
|
| 21 |
+
DEFAULT_SPACE_REPO_ID = "Ev3Dev/hackathon"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def hf_space_repo_to_base_url(repo_id: str) -> str:
    """Map a Space repo id (`owner/space-name`) onto its direct `hf.space` URL."""

    def _slug(part: str) -> str:
        # Hugging Face subdomain labels are lowercase with dashes, not underscores.
        return part.strip().lower().replace("_", "-")

    owner, name = repo_id.split("/", 1)
    return "https://{}-{}.hf.space".format(_slug(owner), _slug(name))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def require_unsloth_base():
    """Import and return the local Unsloth training helpers.

    Unsloth patches trl / transformers / peft at import time, so it must be
    imported before any of those libraries — hence the deferred, ordered
    imports inside this function rather than at module top level.
    """
    import unsloth  # noqa: F401  (must precede trl/transformers/peft)
    import training_unsloth as unsloth_base

    return unsloth_base
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for remote-Space GRPO training runs."""
    lora = unsloth_defaults()
    parser = argparse.ArgumentParser(
        description="Train Unsloth Llama 3.2 3B on a remote OpenEnv Hugging Face Space."
    )
    # Model / output locations.
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    # Remote data collection.
    parser.add_argument("--dataset-episodes", type=int, default=8)
    parser.add_argument("--rollout-steps", type=int, default=6)
    parser.add_argument(
        "--collection-policy",
        choices=["random", "heuristic"],
        default="heuristic",
    )
    parser.add_argument("--base-url", default="")
    parser.add_argument(
        "--space-repo-id",
        default=DEFAULT_SPACE_REPO_ID,
        help="Hugging Face Space repo id, for example `Ev3Dev/hackathon`.",
    )
    # Generation / sequence budgets.
    parser.add_argument("--num-generations", type=int, default=2)
    parser.add_argument("--max-completion-length", type=int, default=160)
    parser.add_argument("--max-prompt-length", type=int, default=1280)
    parser.add_argument("--max-seq-length", type=int, default=2048)
    # Optimisation schedule.
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument("--save-steps", type=int, default=25)
    parser.add_argument("--plot-metric-key", default=None)
    parser.add_argument("--seed", type=int, default=42)
    # Run-mode switches.
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--load-model-only", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--disable-4bit", action="store_true")
    # LoRA hyper-parameters (defaults shared via unsloth_defaults()).
    parser.add_argument("--lora-r", type=int, default=lora["lora_r"])
    parser.add_argument("--lora-alpha", type=int, default=lora["lora_alpha"])
    parser.add_argument("--lora-dropout", type=float, default=lora["lora_dropout"])
    return parser
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def unsloth_defaults() -> Dict[str, float]:
    """Default LoRA hyper-parameters shared by the CLI and notebook flows."""
    defaults: Dict[str, float] = dict(lora_r=16, lora_alpha=16, lora_dropout=0.0)
    return defaults
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI args, deriving --base-url from --space-repo-id when omitted."""
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    args.base_url = args.base_url or hf_space_repo_to_base_url(args.space_repo_id)
    return args
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build a training-args Namespace programmatically (e.g. from a notebook).

    Overrides must match existing CLI flag destinations; unknown keys raise
    ValueError so typos fail loudly instead of being silently ignored.
    """
    defaults = vars(build_argument_parser().parse_args([]))
    unexpected = sorted(key for key in overrides if key not in defaults)
    if unexpected:
        raise ValueError(f"Unknown training args: {', '.join(unexpected)}")
    merged = {**defaults, **overrides}
    args = argparse.Namespace(**merged)
    # Mirror parse_args(): derive the Space URL when no explicit base URL given.
    if not getattr(args, "base_url", ""):
        args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
    return args
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_remote_prompt_examples(args: argparse.Namespace) -> List[Dict[str, str]]:
    """Roll out short episodes on the remote OpenEnv server to collect prompts.

    Each visited state becomes one training example holding the formatted
    prompt, the JSON-encoded action history that produced that state, and the
    collection policy's reference action for it.
    """
    rng = random.Random(args.seed)
    collected: List[Dict[str, str]] = []

    for _ in range(args.dataset_episodes):
        with BioExperimentEnv(base_url=args.base_url) as env:
            step_result = env.reset()
            obs = step_result.observation
            taken: List[base.ExperimentAction] = []

            for step_idx in range(args.rollout_steps):
                if obs.done:
                    break

                chosen_type = base.pick_action(
                    args.collection_policy,
                    step_idx,
                    [prior.action_type for prior in taken],
                )
                proposal = base.build_experiment_action(
                    action_type=chosen_type,
                    discovered_markers=obs.discovered_markers,
                    candidate_mechanisms=obs.candidate_mechanisms,
                    conditions=obs.task.conditions,
                )
                collected.append(
                    {
                        "prompt": base.build_training_prompt(obs),
                        "history_actions": json.dumps(
                            [prior.model_dump() for prior in taken]
                        ),
                        "reference_action": base.action_completion_json(proposal),
                        "problem_statement": obs.task.problem_statement,
                        "episode_tag": f"remote-{rng.randrange(10**9):09d}",
                    }
                )

                taken.append(proposal)
                step_result = env.step(proposal)
                obs = step_result.observation
                if step_result.done:
                    break

    return collected
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class RemoteSpaceReward:
    """Reward function that replays each candidate against the remote Space.

    For every completion, the episode's prior actions are replayed on a fresh
    remote environment so the candidate action is scored from the exact state
    the prompt was built from.
    """

    def __init__(
        self,
        *,
        base_url: str,
        invalid_action_penalty: float = base.INVALID_ACTION_PENALTY,
        environment_error_penalty: float = base.ENVIRONMENT_ERROR_PENALTY,
    ) -> None:
        # trl reads `__name__` off reward callables for logging/metric names.
        self.__name__ = "remote_space_reward"
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty

    @staticmethod
    def _coalesce_reward(*candidates: Optional[float]) -> float:
        """Return the first candidate that is not None, else 0.0.

        BUGFIX: the previous `a or b or 0.0` chain treated a legitimate reward
        of exactly 0.0 as missing (falsy) and fell through to the next value.
        """
        for value in candidates:
            if value is not None:
                return float(value)
        return 0.0

    def __call__(
        self,
        completions: List[Any],
        history_actions: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        """Score a batch of completions; unparseable or failing ones get penalties."""
        history_columns = base.normalise_column(history_actions, len(completions))
        rewards: List[float] = []

        for completion, current_history in zip(completions, history_columns):
            action = base.parse_action_completion(base.completion_to_text(completion))
            if action is None:
                # Model output did not parse into a valid action.
                rewards.append(self.invalid_action_penalty)
                continue

            try:
                rewards.append(self._score_remote(action, current_history))
            except Exception:
                # Remote Space flakiness (network, cold start) must not crash
                # training; fold it into a fixed penalty instead.
                rewards.append(self.environment_error_penalty)

        return rewards

    def _score_remote(
        self,
        action: base.ExperimentAction,
        history_actions: Optional[str],
    ) -> float:
        """Replay `history_actions` on a fresh env, then score `action` there."""
        with BioExperimentEnv(base_url=self.base_url) as env:
            result = env.reset()
            obs = result.observation

            for previous_action in base.decode_history_actions(history_actions):
                result = env.step(previous_action)
                obs = result.observation
                if result.done:
                    # Episode ended during replay: the candidate can no longer
                    # be applied, so return the terminal reward (0.0 is valid).
                    return self._coalesce_reward(result.reward, obs.reward)

            action = base.ensure_conclusion_claims(obs, action)
            result = env.step(action)
            if result.reward is not None:
                return float(result.reward)
            return float(result.observation.reward)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: "RemoteSpaceReward",
    output_dir: str,
    base_url: str,
) -> None:
    """Print a sanity-check summary: one sample prompt plus its reference reward.

    Raises ValueError when no prompt states were collected.
    """
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")

    first = examples[0]
    chat_completion = [{"role": "assistant", "content": first["reference_action"]}]
    reference_reward = reward_fn(
        completions=[chat_completion],
        history_actions=[first["history_actions"]],
    )[0]

    summary = [
        f"Built {len(examples)} remote prompt states.",
        f"Remote OpenEnv Space: {base_url}",
        f"Output directory: {output_dir}",
        f"Sample reward for reference action: {reference_reward:+.3f}",
        "\nSample prompt:\n",
        first["prompt"],
    ]
    for line in summary:
        print(line)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Run the full remote-Space GRPO training flow and return its artifacts.

    Depending on flags, this either loads the model only, previews a dry run,
    or performs the complete collect -> train -> save -> plot pipeline.
    Returns a dict of everything produced so notebook callers can inspect it.
    """
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()
    # Must happen before any trl/transformers import (Unsloth patches them).
    unsloth_base = require_unsloth_base()

    # Mode 1: only load tokenizer/model (e.g. for interactive inspection).
    if args.load_model_only:
        tokenizer, model = unsloth_base.load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    # Collect prompt states from the remote Space and build the reward fn.
    examples = build_remote_prompt_examples(args)
    reward_fn = RemoteSpaceReward(base_url=args.base_url)

    # Mode 2: preview the collected data without training.
    if args.dry_run:
        run_dry_run_preview(examples, reward_fn, args.output_dir, args.base_url)
        return {
            "args": args,
            "runtime": runtime,
            "examples": examples,
            "reward_fn": reward_fn,
        }

    # Deferred import: `datasets` is only needed for a real training run.
    from datasets import Dataset

    FastLanguageModel = unsloth_base.patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)

    tokenizer, model = unsloth_base.load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )
    model = unsloth_base.apply_lora_adapters(FastLanguageModel, model, args)

    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']} "
        f"load_in_4bit={not args.disable_4bit}"
    )
    print(f"Remote OpenEnv Space: {args.base_url}")
    print(f"Collected remote prompt states: {len(examples)}")

    trainer = unsloth_base.build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=reward_fn,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )
    # Some trl versions expect vision token attrs even on text-only trainers;
    # backfill them as None so attribute lookups don't fail mid-training.
    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, attr):
            setattr(trainer, attr, None)

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Persist loss/reward curves next to the model checkpoint.
    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def main() -> None:
|
| 336 |
+
run_training(parse_args())
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
if __name__ == "__main__":
|
| 340 |
+
main()
|
colab_train_unsloth.ipynb
CHANGED
|
@@ -1,128 +1,347 @@
|
|
| 1 |
-
{
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
"
|
| 88 |
-
"
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Minimal Unsloth GRPO on Colab with a remote OpenEnv Space\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook is intentionally similar to the 2048 notebook pattern:\n",
|
| 10 |
+
"- training runs locally inside Colab\n",
|
| 11 |
+
"- the environment is accessed remotely through a Hugging Face Space\n",
|
| 12 |
+
"- the reward function is defined in notebook code by replaying actions against that remote env\n",
|
| 13 |
+
"- prompt / action / conclusion formatting mirrors the repo logic without importing the repo training script\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"Default remote env: `Ev3Dev/hackathon`\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"**Runtime**: Enable a GPU in Colab: Runtime -> Change runtime type -> GPU."
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "code",
|
| 22 |
+
"execution_count": null,
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [],
|
| 25 |
+
"source": [
|
| 26 |
+
"# 1. Clone the repo for lightweight client / model definitions only\n",
|
| 27 |
+
"REPO_URL = \"https://github.com/mhtruong1031/OpenENV-Hackathon.git\" # or your fork\n",
|
| 28 |
+
"REPO_DIR = \"OpenENV-Hackathon\"\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"!git clone --depth 1 {REPO_URL} {REPO_DIR}\n",
|
| 31 |
+
"%cd {REPO_DIR}"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"# 2. Install only the runtime pieces needed for notebook-side training\n",
|
| 41 |
+
"!pip install -q unsloth unsloth_zoo --no-deps\n",
|
| 42 |
+
"!pip install -q \"openenv-core[core]>=0.2.0\" \"pydantic>=2\" \"numpy>=1.24.0\" \"scipy>=1.10.0\" \"datasets>=4.6.1\" \"accelerate>=1.13.0\" \"peft>=0.15.0\" \"bitsandbytes>=0.45.0\" \"matplotlib>=3.8.0\"\n",
|
| 43 |
+
"!pip install -q \"transformers>=4.57.0\" \"trl>=0.29.0\" \"torchvision>=0.20.0\""
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"# 3. Import repo reward helpers, but keep the environment remote\n",
|
| 53 |
+
"import inspect\n",
|
| 54 |
+
"import json\n",
|
| 55 |
+
"import random\n",
|
| 56 |
+
"import sys\n",
|
| 57 |
+
"from pathlib import Path\n",
|
| 58 |
+
"from typing import Any, Dict, List\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# Unsloth must be imported before trl / transformers / peft.\n",
|
| 61 |
+
"import unsloth # noqa: F401\n",
|
| 62 |
+
"import torch\n",
|
| 63 |
+
"from unsloth import FastLanguageModel, PatchFastRL\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"sys.path.insert(0, str(Path.cwd()))\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"from client import BioExperimentEnv\n",
|
| 68 |
+
"from models import ActionType, ExperimentAction\n",
|
| 69 |
+
"from training_script import (\n",
|
| 70 |
+
" INVALID_ACTION_PENALTY,\n",
|
| 71 |
+
" ENVIRONMENT_ERROR_PENALTY,\n",
|
| 72 |
+
" OpenEnvReward,\n",
|
| 73 |
+
" build_training_prompt,\n",
|
| 74 |
+
" build_experiment_action,\n",
|
| 75 |
+
" decode_history_actions,\n",
|
| 76 |
+
" pick_action,\n",
|
| 77 |
+
" save_training_plots,\n",
|
| 78 |
+
")\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"MAX_COMPLETION_TOKENS = 160\n",
|
| 81 |
+
"LORA_TARGET_MODULES = [\n",
|
| 82 |
+
" \"q_proj\",\n",
|
| 83 |
+
" \"k_proj\",\n",
|
| 84 |
+
" \"v_proj\",\n",
|
| 85 |
+
" \"o_proj\",\n",
|
| 86 |
+
" \"gate_proj\",\n",
|
| 87 |
+
" \"up_proj\",\n",
|
| 88 |
+
" \"down_proj\",\n",
|
| 89 |
+
"]\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"def hf_space_repo_to_base_url(repo_id: str) -> str:\n",
|
| 93 |
+
" owner, space_name = repo_id.split(\"/\", 1)\n",
|
| 94 |
+
" return f\"https://{owner.lower().replace('_', '-')}-{space_name.lower().replace('_', '-')}.hf.space\"\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"def build_remote_prompt_examples(\n",
|
| 98 |
+
" base_url: str,\n",
|
| 99 |
+
" dataset_episodes: int,\n",
|
| 100 |
+
" rollout_steps: int,\n",
|
| 101 |
+
" seed: int,\n",
|
| 102 |
+
") -> List[Dict[str, str]]:\n",
|
| 103 |
+
" rng = random.Random(seed)\n",
|
| 104 |
+
" examples: List[Dict[str, str]] = []\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" for _ in range(dataset_episodes):\n",
|
| 107 |
+
" with BioExperimentEnv(base_url=base_url) as env:\n",
|
| 108 |
+
" result = env.reset()\n",
|
| 109 |
+
" obs = result.observation\n",
|
| 110 |
+
" history_actions: List[ExperimentAction] = []\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" for step_idx in range(rollout_steps):\n",
|
| 113 |
+
" if obs.done:\n",
|
| 114 |
+
" break\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" next_action = build_experiment_action(\n",
|
| 117 |
+
" action_type=pick_action(\n",
|
| 118 |
+
" \"heuristic\",\n",
|
| 119 |
+
" step_idx,\n",
|
| 120 |
+
" [action.action_type for action in history_actions],\n",
|
| 121 |
+
" ),\n",
|
| 122 |
+
" discovered_markers=obs.discovered_markers,\n",
|
| 123 |
+
" candidate_mechanisms=obs.candidate_mechanisms,\n",
|
| 124 |
+
" conditions=obs.task.conditions,\n",
|
| 125 |
+
" )\n",
|
| 126 |
+
" examples.append(\n",
|
| 127 |
+
" {\n",
|
| 128 |
+
" \"prompt\": build_training_prompt(obs),\n",
|
| 129 |
+
" \"history_actions\": json.dumps(\n",
|
| 130 |
+
" [action.model_dump() for action in history_actions]\n",
|
| 131 |
+
" ),\n",
|
| 132 |
+
" \"reference_action\": json.dumps(next_action.model_dump()),\n",
|
| 133 |
+
" \"problem_statement\": obs.task.problem_statement,\n",
|
| 134 |
+
" \"episode_tag\": f\"remote-{rng.randrange(10**9):09d}\",\n",
|
| 135 |
+
" }\n",
|
| 136 |
+
" )\n",
|
| 137 |
+
"\n",
|
| 138 |
+
" history_actions.append(next_action)\n",
|
| 139 |
+
" result = env.step(next_action)\n",
|
| 140 |
+
" obs = result.observation\n",
|
| 141 |
+
" if result.done:\n",
|
| 142 |
+
" break\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" return examples\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"def build_grpo_config(**overrides: Any):\n",
|
| 148 |
+
" from trl import GRPOConfig\n",
|
| 149 |
+
"\n",
|
| 150 |
+
" supported = set(inspect.signature(GRPOConfig.__init__).parameters)\n",
|
| 151 |
+
" config_kwargs = {\n",
|
| 152 |
+
" \"output_dir\": overrides[\"output_dir\"],\n",
|
| 153 |
+
" \"learning_rate\": overrides[\"learning_rate\"],\n",
|
| 154 |
+
" \"per_device_train_batch_size\": overrides[\"per_device_train_batch_size\"],\n",
|
| 155 |
+
" \"gradient_accumulation_steps\": overrides[\"gradient_accumulation_steps\"],\n",
|
| 156 |
+
" \"num_generations\": overrides[\"num_generations\"],\n",
|
| 157 |
+
" \"max_completion_length\": overrides[\"max_completion_length\"],\n",
|
| 158 |
+
" \"num_train_epochs\": overrides[\"num_train_epochs\"],\n",
|
| 159 |
+
" \"logging_steps\": overrides[\"logging_steps\"],\n",
|
| 160 |
+
" \"save_steps\": overrides[\"save_steps\"],\n",
|
| 161 |
+
" \"bf16\": overrides[\"bf16\"],\n",
|
| 162 |
+
" \"fp16\": overrides[\"fp16\"],\n",
|
| 163 |
+
" \"report_to\": \"none\",\n",
|
| 164 |
+
" \"remove_unused_columns\": False,\n",
|
| 165 |
+
" }\n",
|
| 166 |
+
" # Keep prompt truncation enabled. Leaving this as None can trigger\n",
|
| 167 |
+
" # an Unsloth rotary-cache shape mismatch on long GRPO prompts.\n",
|
| 168 |
+
" if \"max_prompt_length\" in supported:\n",
|
| 169 |
+
" config_kwargs[\"max_prompt_length\"] = overrides[\"max_prompt_length\"]\n",
|
| 170 |
+
" if (\n",
|
| 171 |
+
" \"max_length\" in supported\n",
|
| 172 |
+
" and \"max_prompt_length\" not in supported\n",
|
| 173 |
+
" and \"max_completion_length\" not in supported\n",
|
| 174 |
+
" ):\n",
|
| 175 |
+
" config_kwargs[\"max_length\"] = (\n",
|
| 176 |
+
" overrides[\"max_prompt_length\"] + overrides[\"max_completion_length\"]\n",
|
| 177 |
+
" )\n",
|
| 178 |
+
" return GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported})\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"print(\"CUDA:\", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"\")\n",
|
| 182 |
+
"Path(\"artifacts\").mkdir(exist_ok=True)"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"# 4. Config + collect prompt states from the remote Space\n",
|
| 192 |
+
"SPACE_REPO_ID = \"Ev3Dev/hackathon\"\n",
|
| 193 |
+
"SPACE_BASE_URL = hf_space_repo_to_base_url(SPACE_REPO_ID)\n",
|
| 194 |
+
"# If your Space has a custom domain, replace SPACE_BASE_URL manually.\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"MODEL_ID = \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\"\n",
|
| 197 |
+
"OUTPUT_DIR = \"artifacts/grpo-unsloth-llama32-3b-remote-space\"\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"DATASET_EPISODES = 8\n",
|
| 200 |
+
"ROLLOUT_STEPS = 6\n",
|
| 201 |
+
"NUM_GENERATIONS = 2\n",
|
| 202 |
+
"# Keep this modest for Unsloth GRPO stability on Colab.\n",
|
| 203 |
+
"MAX_PROMPT_LENGTH = 768\n",
|
| 204 |
+
"MAX_SEQ_LENGTH = 2048\n",
|
| 205 |
+
"PER_DEVICE_TRAIN_BATCH_SIZE = 1\n",
|
| 206 |
+
"GRADIENT_ACCUMULATION_STEPS = 4\n",
|
| 207 |
+
"LEARNING_RATE = 5e-6\n",
|
| 208 |
+
"NUM_TRAIN_EPOCHS = 1.0\n",
|
| 209 |
+
"LOGGING_STEPS = 1\n",
|
| 210 |
+
"SAVE_STEPS = 25\n",
|
| 211 |
+
"SEED = 42\n",
|
| 212 |
+
"LORA_R = 16\n",
|
| 213 |
+
"LORA_ALPHA = 16\n",
|
| 214 |
+
"LORA_DROPOUT = 0.0\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"examples = build_remote_prompt_examples(\n",
|
| 217 |
+
" base_url=SPACE_BASE_URL,\n",
|
| 218 |
+
" dataset_episodes=DATASET_EPISODES,\n",
|
| 219 |
+
" rollout_steps=ROLLOUT_STEPS,\n",
|
| 220 |
+
" seed=SEED,\n",
|
| 221 |
+
")\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"reward_fn = OpenEnvReward(\n",
|
| 224 |
+
" reward_backend=\"remote\",\n",
|
| 225 |
+
" base_url=SPACE_BASE_URL,\n",
|
| 226 |
+
" invalid_action_penalty=INVALID_ACTION_PENALTY,\n",
|
| 227 |
+
" environment_error_penalty=ENVIRONMENT_ERROR_PENALTY,\n",
|
| 228 |
+
")\n",
|
| 229 |
+
"\n",
|
| 230 |
+
"print(\"Remote env:\", SPACE_BASE_URL)\n",
|
| 231 |
+
"print(\"Prompt states:\", len(examples))\n",
|
| 232 |
+
"print(\"Sample prompt preview:\\n\")\n",
|
| 233 |
+
"print(examples[0][\"prompt\"][:2000])"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"cell_type": "code",
|
| 238 |
+
"execution_count": null,
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"outputs": [],
|
| 241 |
+
"source": [
|
| 242 |
+
"# 5. Local GRPO training in Colab, remote env for rewards\n",
|
| 243 |
+
"from datasets import Dataset\n",
|
| 244 |
+
"from trl import GRPOTrainer\n",
|
| 245 |
+
"\n",
|
| 246 |
+
"PatchFastRL(\"GRPO\", FastLanguageModel)\n",
|
| 247 |
+
"train_dataset = Dataset.from_list(examples)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"bf16 = bool(getattr(torch.cuda, \"is_bf16_supported\", lambda: False)()) if torch.cuda.is_available() else False\n",
|
| 250 |
+
"runtime_dtype = torch.bfloat16 if bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 253 |
+
" model_name=MODEL_ID,\n",
|
| 254 |
+
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 255 |
+
" dtype=runtime_dtype,\n",
|
| 256 |
+
" load_in_4bit=True,\n",
|
| 257 |
+
")\n",
|
| 258 |
+
"if tokenizer.pad_token is None and tokenizer.eos_token is not None:\n",
|
| 259 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 262 |
+
" model,\n",
|
| 263 |
+
" r=LORA_R,\n",
|
| 264 |
+
" target_modules=LORA_TARGET_MODULES,\n",
|
| 265 |
+
" lora_alpha=LORA_ALPHA,\n",
|
| 266 |
+
" lora_dropout=LORA_DROPOUT,\n",
|
| 267 |
+
" bias=\"none\",\n",
|
| 268 |
+
" use_gradient_checkpointing=True,\n",
|
| 269 |
+
" random_state=SEED,\n",
|
| 270 |
+
")\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"training_args = build_grpo_config(\n",
|
| 273 |
+
" output_dir=OUTPUT_DIR,\n",
|
| 274 |
+
" learning_rate=LEARNING_RATE,\n",
|
| 275 |
+
" per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,\n",
|
| 276 |
+
" gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
|
| 277 |
+
" num_generations=NUM_GENERATIONS,\n",
|
| 278 |
+
" max_completion_length=MAX_COMPLETION_TOKENS,\n",
|
| 279 |
+
" max_prompt_length=MAX_PROMPT_LENGTH,\n",
|
| 280 |
+
" num_train_epochs=NUM_TRAIN_EPOCHS,\n",
|
| 281 |
+
" logging_steps=LOGGING_STEPS,\n",
|
| 282 |
+
" save_steps=SAVE_STEPS,\n",
|
| 283 |
+
" bf16=bf16,\n",
|
| 284 |
+
" fp16=torch.cuda.is_available() and not bf16,\n",
|
| 285 |
+
")\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"trainer = GRPOTrainer(\n",
|
| 288 |
+
" model=model,\n",
|
| 289 |
+
" reward_funcs=[reward_fn],\n",
|
| 290 |
+
" args=training_args,\n",
|
| 291 |
+
" train_dataset=train_dataset,\n",
|
| 292 |
+
" processing_class=tokenizer,\n",
|
| 293 |
+
")\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"for attr in (\"image_token_id\", \"vision_start_token_id\", \"vision_end_token_id\"):\n",
|
| 296 |
+
" if not hasattr(trainer, attr):\n",
|
| 297 |
+
" setattr(trainer, attr, None)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"trainer.train()\n",
|
| 300 |
+
"trainer.save_model(OUTPUT_DIR)\n",
|
| 301 |
+
"tokenizer.save_pretrained(OUTPUT_DIR)\n",
|
| 302 |
+
"plot_paths = save_training_plots(trainer.state.log_history, OUTPUT_DIR)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"result = {\n",
|
| 305 |
+
" \"trainer\": trainer,\n",
|
| 306 |
+
" \"plot_paths\": plot_paths,\n",
|
| 307 |
+
" \"output_dir\": OUTPUT_DIR,\n",
|
| 308 |
+
"}\n",
|
| 309 |
+
"print(\"Saved to:\", OUTPUT_DIR)\n",
|
| 310 |
+
"print(\"Plots:\", plot_paths)"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"cell_type": "code",
|
| 315 |
+
"execution_count": null,
|
| 316 |
+
"metadata": {},
|
| 317 |
+
"outputs": [],
|
| 318 |
+
"source": [
|
| 319 |
+
"# 6. (Optional) Show curves and sanity-check the repo reward wrapper\n",
|
| 320 |
+
"from IPython.display import Image, display\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"sample_reward = reward_fn(\n",
|
| 323 |
+
" completions=[[{\"role\": \"assistant\", \"content\": examples[0][\"reference_action\"]}]],\n",
|
| 324 |
+
" history_actions=[examples[0][\"history_actions\"]],\n",
|
| 325 |
+
")[0]\n",
|
| 326 |
+
"print(\"Sample reward for reference action:\", sample_reward)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"for name, path in (result.get(\"plot_paths\") or {}).items():\n",
|
| 329 |
+
" print(name, path)\n",
|
| 330 |
+
" display(Image(filename=path))"
|
| 331 |
+
]
|
| 332 |
+
}
|
| 333 |
+
],
|
| 334 |
+
"metadata": {
|
| 335 |
+
"kernelspec": {
|
| 336 |
+
"display_name": "Python 3",
|
| 337 |
+
"language": "python",
|
| 338 |
+
"name": "python3"
|
| 339 |
+
},
|
| 340 |
+
"language_info": {
|
| 341 |
+
"name": "python",
|
| 342 |
+
"version": "3.10.0"
|
| 343 |
+
}
|
| 344 |
+
},
|
| 345 |
+
"nbformat": 4,
|
| 346 |
+
"nbformat_minor": 4
|
| 347 |
+
}
|
pyproject.toml
CHANGED
|
@@ -40,8 +40,6 @@ train = [
|
|
| 40 |
"torch>=2.10.0",
|
| 41 |
"torchvision>=0.20.0", # required by transformers for Qwen3.5 (image_utils)
|
| 42 |
"transformers>=5.3.0", # 5.3+ required for Qwen3.5 (qwen3_5 model type)
|
| 43 |
-
"llm-blender>=0.0.2", # required by trl GRPOTrainer judges
|
| 44 |
-
"mergekit>=0.1.0", # required by trl GRPOTrainer/callbacks
|
| 45 |
"trl>=0.29.0", # GRPOTrainer; 0.29+ compatible with transformers 5.3
|
| 46 |
]
|
| 47 |
|
|
|
|
| 40 |
"torch>=2.10.0",
|
| 41 |
"torchvision>=0.20.0", # required by transformers for Qwen3.5 (image_utils)
|
| 42 |
"transformers>=5.3.0", # 5.3+ required for Qwen3.5 (qwen3_5 model type)
|
|
|
|
|
|
|
| 43 |
"trl>=0.29.0", # GRPOTrainer; 0.29+ compatible with transformers 5.3
|
| 44 |
]
|
| 45 |
|
tests/test_colab_train_llama32_remote.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from colab_train_llama32_remote import (
|
| 2 |
+
hf_space_repo_to_base_url,
|
| 3 |
+
make_training_args,
|
| 4 |
+
)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_hf_space_repo_to_base_url_formats_standard_hf_space_domain():
|
| 8 |
+
assert (
|
| 9 |
+
hf_space_repo_to_base_url("Ev3Dev/hackathon")
|
| 10 |
+
== "https://ev3dev-hackathon.hf.space"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_make_training_args_derives_base_url_when_missing():
|
| 15 |
+
args = make_training_args(space_repo_id="Ev3Dev/hackathon", base_url="")
|
| 16 |
+
assert args.base_url == "https://ev3dev-hackathon.hf.space"
|
| 17 |
+
assert args.model_id == "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
|
unsloth_2048.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py
ADDED
|
File without changes
|
unsloth_compiled_cache/UnslothCPOTrainer.py
CHANGED
|
@@ -28,7 +28,7 @@ import torch.nn as nn
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
-
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
|
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
unsloth_compiled_cache/UnslothDPOTrainer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/UnslothGRPOTrainer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/UnslothKTOTrainer.py
CHANGED
|
@@ -28,7 +28,7 @@ import torch.nn as nn
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
-
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
|
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
unsloth_compiled_cache/UnslothNashMDTrainer.py
ADDED
|
@@ -0,0 +1,1340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2026.3.2
|
| 3 |
+
2026.3.4
|
| 4 |
+
5.3.0
|
| 5 |
+
0.24.0
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth auto generated code
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from torch.nn import functional as F
|
| 29 |
+
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
import math
|
| 36 |
+
import logging
|
| 37 |
+
from typing import *
|
| 38 |
+
from dataclasses import dataclass, field
|
| 39 |
+
from packaging.version import Version
|
| 40 |
+
import torch
|
| 41 |
+
import numpy as np
|
| 42 |
+
from contextlib import nullcontext
|
| 43 |
+
from torch.nn import functional as F
|
| 44 |
+
import inspect
|
| 45 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 46 |
+
from transformers.training_args import ParallelMode
|
| 47 |
+
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
|
| 48 |
+
|
| 49 |
+
# Wrap trainer with padding to right and enable training mode
|
| 50 |
+
# Also patches W&B since multiple runs must use wandb.finish()
|
| 51 |
+
import functools
|
| 52 |
+
from types import MethodType
|
| 53 |
+
try:
|
| 54 |
+
from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
|
| 55 |
+
except:
|
| 56 |
+
def reset_unsloth_gradient_checkpointing_buffers(): pass
|
| 57 |
+
def prepare_for_training_mode(f):
|
| 58 |
+
@functools.wraps(f)
|
| 59 |
+
def wrapper(self, *args, **kwargs):
|
| 60 |
+
# Enable training mode
|
| 61 |
+
_was_training = None
|
| 62 |
+
# Get gradient checkpointing setting from training arguments
|
| 63 |
+
use_gc = getattr(self.args, 'gradient_checkpointing', True)
|
| 64 |
+
if hasattr(self, 'model') and hasattr(self.model, "training"):
|
| 65 |
+
_was_training = self.model.training
|
| 66 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_training"):
|
| 67 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 68 |
+
output = f(self, *args, **kwargs)
|
| 69 |
+
# Restore previous mode when possible
|
| 70 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
|
| 71 |
+
if _was_training is False:
|
| 72 |
+
self.model.for_inference()
|
| 73 |
+
elif _was_training is True and hasattr(self.model, "for_training"):
|
| 74 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 75 |
+
# Reset gradient checkpointing buffers to free memory while staying ready for next run
|
| 76 |
+
try:
|
| 77 |
+
reset_unsloth_gradient_checkpointing_buffers()
|
| 78 |
+
except:
|
| 79 |
+
pass
|
| 80 |
+
# Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
|
| 81 |
+
try:
|
| 82 |
+
import wandb
|
| 83 |
+
wandb.finish()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
return output
|
| 87 |
+
return wrapper
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
torch_compile_options = {
|
| 91 |
+
"epilogue_fusion" : True,
|
| 92 |
+
"max_autotune" : False,
|
| 93 |
+
"shape_padding" : True,
|
| 94 |
+
"trace.enabled" : False,
|
| 95 |
+
"triton.cudagraphs" : False,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 99 |
+
def chunked_hidden_states_selective_log_softmax(
|
| 100 |
+
hidden_states: torch.Tensor,
|
| 101 |
+
lm_head: torch.Tensor,
|
| 102 |
+
index: torch.Tensor,
|
| 103 |
+
chunks: int = 4,
|
| 104 |
+
logit_scale_multiply: float = 0.0,
|
| 105 |
+
logit_scale_divide: float = 0.0,
|
| 106 |
+
logit_softcapping: float = 0.0,
|
| 107 |
+
temperature: float = 1.0,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
# All Unsloth Zoo code licensed under AGPL3
|
| 110 |
+
flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
|
| 111 |
+
flat_index = index.reshape(-1)
|
| 112 |
+
|
| 113 |
+
chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
|
| 114 |
+
chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
|
| 115 |
+
|
| 116 |
+
all_per_token_logps = []
|
| 117 |
+
|
| 118 |
+
for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
|
| 119 |
+
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
|
| 120 |
+
|
| 121 |
+
if logit_scale_multiply != 0.0:
|
| 122 |
+
chunk_logits = chunk_logits * logit_scale_multiply
|
| 123 |
+
if logit_scale_divide != 0.0:
|
| 124 |
+
chunk_logits = chunk_logits / logit_scale_divide
|
| 125 |
+
if logit_softcapping != 0.0:
|
| 126 |
+
chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
|
| 127 |
+
|
| 128 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 129 |
+
|
| 130 |
+
if temperature != 1.0:
|
| 131 |
+
chunk_logits = chunk_logits / temperature
|
| 132 |
+
|
| 133 |
+
selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 134 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
|
| 135 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 136 |
+
all_per_token_logps.append(per_token_logps)
|
| 137 |
+
|
| 138 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 139 |
+
|
| 140 |
+
all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
|
| 141 |
+
return all_per_token_logps
|
| 142 |
+
|
| 143 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 144 |
+
def chunked_selective_log_softmax(logits, index):
|
| 145 |
+
# Split into 4 chunks only
|
| 146 |
+
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 147 |
+
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 148 |
+
all_per_token_logps = []
|
| 149 |
+
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 150 |
+
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 151 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 152 |
+
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 153 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 154 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 155 |
+
all_per_token_logps.append(per_token_logps)
|
| 156 |
+
pass
|
| 157 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 158 |
+
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 159 |
+
return all_per_token_logps
|
| 160 |
+
|
| 161 |
+
def calculate_pad_tokens_in_prompt(
|
| 162 |
+
input_ids: torch.Tensor,
|
| 163 |
+
logits_to_keep: int,
|
| 164 |
+
pad_token_id: int
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
|
| 168 |
+
"""
|
| 169 |
+
if logits_to_keep >= input_ids.shape[1]:
|
| 170 |
+
raise ValueError("logits_to_keep must be smaller than the sequence length.")
|
| 171 |
+
|
| 172 |
+
prompt_section = input_ids[:, :-logits_to_keep]
|
| 173 |
+
|
| 174 |
+
padding_mask = (prompt_section == pad_token_id)
|
| 175 |
+
|
| 176 |
+
pad_token_counts = padding_mask.sum(dim=1)
|
| 177 |
+
|
| 178 |
+
return pad_token_counts
|
| 179 |
+
|
| 180 |
+
def create_completion_attention_mask(
|
| 181 |
+
completion_input_ids: torch.Tensor,
|
| 182 |
+
left_pad_tokens_per_prompt: torch.Tensor,
|
| 183 |
+
max_left_pad: int,
|
| 184 |
+
pad_token_id: int
|
| 185 |
+
) -> torch.Tensor:
|
| 186 |
+
"""
|
| 187 |
+
Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
|
| 188 |
+
|
| 189 |
+
Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
|
| 190 |
+
and pad are pad tokens, this function would make a completion mask that would 0 out the pad
|
| 191 |
+
and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
|
| 192 |
+
"""
|
| 193 |
+
batch_size, completion_len = completion_input_ids.shape
|
| 194 |
+
device = completion_input_ids.device
|
| 195 |
+
|
| 196 |
+
num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
|
| 197 |
+
|
| 198 |
+
indices = torch.arange(completion_len, device=device).unsqueeze(0)
|
| 199 |
+
shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
|
| 200 |
+
|
| 201 |
+
non_padding_mask = (completion_input_ids != pad_token_id)
|
| 202 |
+
|
| 203 |
+
final_mask = shift_mask & non_padding_mask
|
| 204 |
+
|
| 205 |
+
return final_mask
|
| 206 |
+
|
| 207 |
+
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
|
| 208 |
+
"""
|
| 209 |
+
Moves all padding tokens in each sequence of a batch to the right.
|
| 210 |
+
"""
|
| 211 |
+
mask = (tensor != pad_id)
|
| 212 |
+
# Must do stable=True since binary mark is unordered
|
| 213 |
+
sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
|
| 214 |
+
packed_tensor = torch.gather(tensor, 1, sorted_indices)
|
| 215 |
+
return packed_tensor
|
| 216 |
+
|
| 217 |
+
def align_logprobs_with_mask(
|
| 218 |
+
logprob_tensor: torch.Tensor,
|
| 219 |
+
attention_mask: torch.Tensor,
|
| 220 |
+
pad_value: float = 0.0
|
| 221 |
+
) -> torch.Tensor:
|
| 222 |
+
"""
|
| 223 |
+
Aligns a log probability tensor with a given attention mask.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
device = logprob_tensor.device
|
| 227 |
+
batch_size, logprob_seq_len = logprob_tensor.shape
|
| 228 |
+
mask_seq_len = attention_mask.shape[1]
|
| 229 |
+
|
| 230 |
+
padded_logprobs = torch.full(
|
| 231 |
+
attention_mask.shape,
|
| 232 |
+
fill_value=pad_value,
|
| 233 |
+
dtype=logprob_tensor.dtype,
|
| 234 |
+
device=device
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
left_pad_counts = torch.argmax(attention_mask, dim=1)
|
| 238 |
+
|
| 239 |
+
cols = torch.arange(logprob_seq_len, device=device)
|
| 240 |
+
dest_indices = left_pad_counts.unsqueeze(1) + cols
|
| 241 |
+
|
| 242 |
+
# Create destination row indices
|
| 243 |
+
# Shape: [batch_size, logprob_seq_len]
|
| 244 |
+
row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
|
| 245 |
+
|
| 246 |
+
# --- 4. Filter out-of-bounds indices and perform assignment ---
|
| 247 |
+
# Create a mask to identify only the indices that are within the bounds
|
| 248 |
+
# of the target tensor's sequence length.
|
| 249 |
+
valid_mask = dest_indices < mask_seq_len
|
| 250 |
+
|
| 251 |
+
# Use this mask to select only the valid row indices, column indices,
|
| 252 |
+
# and the corresponding values from the logprob tensor.
|
| 253 |
+
# This flattens the selected elements into 1D tensors.
|
| 254 |
+
valid_rows = row_indices[valid_mask]
|
| 255 |
+
valid_cols = dest_indices[valid_mask]
|
| 256 |
+
valid_vals = logprob_tensor[valid_mask]
|
| 257 |
+
|
| 258 |
+
# Place the valid values into their correct positions in the padded tensor
|
| 259 |
+
# using a single, efficient advanced indexing operation.
|
| 260 |
+
padded_logprobs[valid_rows, valid_cols] = valid_vals
|
| 261 |
+
|
| 262 |
+
return padded_logprobs
|
| 263 |
+
|
| 264 |
+
def autotune_batch_and_chunks(
|
| 265 |
+
total_input_rows,
|
| 266 |
+
seq_len,
|
| 267 |
+
hidden_size,
|
| 268 |
+
vocab_size,
|
| 269 |
+
dtype_bytes=16,
|
| 270 |
+
multiplier=None
|
| 271 |
+
):
|
| 272 |
+
if multiplier is None:
|
| 273 |
+
final_m = max(4, seq_len // 4096)
|
| 274 |
+
else:
|
| 275 |
+
final_m = multiplier
|
| 276 |
+
|
| 277 |
+
if torch.cuda.is_available():
|
| 278 |
+
free_bytes, _ = torch.cuda.mem_get_info()
|
| 279 |
+
limit_gb = (free_bytes / (1024**3))*.80
|
| 280 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 281 |
+
# For XPU: estimate free memory from total - reserved
|
| 282 |
+
total_mem = torch.xpu.get_device_properties(0).total_memory
|
| 283 |
+
reserved_mem = torch.xpu.memory_reserved()
|
| 284 |
+
free_bytes = total_mem - reserved_mem
|
| 285 |
+
limit_gb = (free_bytes / (1024**3)) * 0.80
|
| 286 |
+
else:
|
| 287 |
+
# Fallback: assume 8GB available
|
| 288 |
+
limit_gb = 8.0
|
| 289 |
+
|
| 290 |
+
bytes_to_gb = 1024**3
|
| 291 |
+
|
| 292 |
+
b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
|
| 295 |
+
|
| 296 |
+
base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
|
| 297 |
+
logits_gb = base_logits / final_m
|
| 298 |
+
|
| 299 |
+
total_mem_gb = hidden_gb + logits_gb
|
| 300 |
+
|
| 301 |
+
valid_mask = total_mem_gb <= limit_gb
|
| 302 |
+
valid_indices = torch.nonzero(valid_mask, as_tuple=False)
|
| 303 |
+
|
| 304 |
+
if valid_indices.shape[0] == 0:
|
| 305 |
+
#This means your GPU will OOM
|
| 306 |
+
return 4, final_m
|
| 307 |
+
|
| 308 |
+
best_idx = valid_indices[0].item()
|
| 309 |
+
final_b = int(b_vals[best_idx].item())
|
| 310 |
+
|
| 311 |
+
return final_b, final_m
|
| 312 |
+
|
| 313 |
+
def sanitize_logprob(logprob):
|
| 314 |
+
"""Local port of trl.scripts.vllm_serve.sanitize_logprob.
|
| 315 |
+
Filters NaN logprobs from vLLM outputs."""
|
| 316 |
+
value = logprob.logprob
|
| 317 |
+
if math.isnan(value):
|
| 318 |
+
logging.getLogger(__name__).warning(
|
| 319 |
+
f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
|
| 320 |
+
)
|
| 321 |
+
return None
|
| 322 |
+
return value
|
| 323 |
+
@dataclass
|
| 324 |
+
class UnslothNashMDConfig(NashMDConfig):
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
Configuration class for the [`NashMDTrainer`].
|
| 328 |
+
|
| 329 |
+
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
| 330 |
+
|
| 331 |
+
Parameters:
|
| 332 |
+
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
|
| 333 |
+
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
|
| 334 |
+
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
|
| 335 |
+
epochs.
|
| 336 |
+
|
| 337 |
+
"""
|
| 338 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 339 |
+
default = None,
|
| 340 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 341 |
+
)
|
| 342 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 343 |
+
default = -1,
|
| 344 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 345 |
+
)
|
| 346 |
+
unsloth_logit_chunk_multiplier : Optional[int] = field(
|
| 347 |
+
default = None,
|
| 348 |
+
metadata = {'help': 'Multiplier for chunked logit computations.'},
|
| 349 |
+
)
|
| 350 |
+
unsloth_grpo_mini_batch : Optional[int] = field(
|
| 351 |
+
default = None,
|
| 352 |
+
metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
|
| 353 |
+
)
|
| 354 |
+
max_seq_length : Optional[int] = field(
|
| 355 |
+
default = None,
|
| 356 |
+
metadata = {'help': 'Maximum sequence length to truncate to.'},
|
| 357 |
+
)
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
output_dir = None,
|
| 361 |
+
per_device_train_batch_size = 4,
|
| 362 |
+
num_train_epochs = 3.0,
|
| 363 |
+
max_steps = -1,
|
| 364 |
+
learning_rate = 5e-05,
|
| 365 |
+
lr_scheduler_type = 'linear',
|
| 366 |
+
lr_scheduler_kwargs = None,
|
| 367 |
+
warmup_steps = 0.1,
|
| 368 |
+
optim = 'adamw_8bit',
|
| 369 |
+
optim_args = None,
|
| 370 |
+
weight_decay = 0.01,
|
| 371 |
+
adam_beta1 = 0.9,
|
| 372 |
+
adam_beta2 = 0.999,
|
| 373 |
+
adam_epsilon = 1e-08,
|
| 374 |
+
optim_target_modules = None,
|
| 375 |
+
gradient_accumulation_steps = 2,
|
| 376 |
+
average_tokens_across_devices = True,
|
| 377 |
+
max_grad_norm = 1.0,
|
| 378 |
+
label_smoothing_factor = 0.0,
|
| 379 |
+
bf16 = False,
|
| 380 |
+
fp16 = False,
|
| 381 |
+
bf16_full_eval = False,
|
| 382 |
+
fp16_full_eval = False,
|
| 383 |
+
tf32 = None,
|
| 384 |
+
gradient_checkpointing = True,
|
| 385 |
+
gradient_checkpointing_kwargs = None,
|
| 386 |
+
torch_compile = False,
|
| 387 |
+
torch_compile_backend = None,
|
| 388 |
+
torch_compile_mode = None,
|
| 389 |
+
use_liger_kernel = False,
|
| 390 |
+
liger_kernel_config = None,
|
| 391 |
+
use_cache = False,
|
| 392 |
+
neftune_noise_alpha = None,
|
| 393 |
+
torch_empty_cache_steps = 250,
|
| 394 |
+
auto_find_batch_size = False,
|
| 395 |
+
logging_strategy = 'steps',
|
| 396 |
+
logging_steps = 1,
|
| 397 |
+
logging_first_step = False,
|
| 398 |
+
log_on_each_node = True,
|
| 399 |
+
logging_nan_inf_filter = False,
|
| 400 |
+
include_num_input_tokens_seen = False,
|
| 401 |
+
log_level = 'passive',
|
| 402 |
+
log_level_replica = 'warning',
|
| 403 |
+
disable_tqdm = None,
|
| 404 |
+
report_to = 'none',
|
| 405 |
+
run_name = None,
|
| 406 |
+
project = 'huggingface',
|
| 407 |
+
trackio_space_id = 'trackio',
|
| 408 |
+
eval_strategy = 'no',
|
| 409 |
+
eval_steps = None,
|
| 410 |
+
eval_delay = 0,
|
| 411 |
+
per_device_eval_batch_size = 4,
|
| 412 |
+
prediction_loss_only = False,
|
| 413 |
+
eval_on_start = False,
|
| 414 |
+
eval_do_concat_batches = True,
|
| 415 |
+
eval_use_gather_object = False,
|
| 416 |
+
eval_accumulation_steps = 2,
|
| 417 |
+
batch_eval_metrics = False,
|
| 418 |
+
save_only_model = False,
|
| 419 |
+
save_strategy = 'steps',
|
| 420 |
+
save_steps = 500,
|
| 421 |
+
save_on_each_node = False,
|
| 422 |
+
save_total_limit = None,
|
| 423 |
+
enable_jit_checkpoint = False,
|
| 424 |
+
push_to_hub = False,
|
| 425 |
+
hub_token = None,
|
| 426 |
+
hub_private_repo = None,
|
| 427 |
+
hub_model_id = None,
|
| 428 |
+
hub_strategy = 'every_save',
|
| 429 |
+
hub_always_push = False,
|
| 430 |
+
hub_revision = None,
|
| 431 |
+
load_best_model_at_end = False,
|
| 432 |
+
metric_for_best_model = None,
|
| 433 |
+
greater_is_better = None,
|
| 434 |
+
ignore_data_skip = False,
|
| 435 |
+
restore_callback_states_from_checkpoint = False,
|
| 436 |
+
full_determinism = False,
|
| 437 |
+
seed = 3407,
|
| 438 |
+
data_seed = 3407,
|
| 439 |
+
use_cpu = False,
|
| 440 |
+
accelerator_config = None,
|
| 441 |
+
parallelism_config = None,
|
| 442 |
+
dataloader_drop_last = False,
|
| 443 |
+
dataloader_num_workers = 0,
|
| 444 |
+
dataloader_pin_memory = True,
|
| 445 |
+
dataloader_persistent_workers = False,
|
| 446 |
+
dataloader_prefetch_factor = None,
|
| 447 |
+
remove_unused_columns = True,
|
| 448 |
+
label_names = None,
|
| 449 |
+
train_sampling_strategy = 'random',
|
| 450 |
+
length_column_name = 'length',
|
| 451 |
+
ddp_find_unused_parameters = None,
|
| 452 |
+
ddp_bucket_cap_mb = None,
|
| 453 |
+
ddp_broadcast_buffers = None,
|
| 454 |
+
ddp_backend = None,
|
| 455 |
+
ddp_timeout = 1800,
|
| 456 |
+
fsdp = None,
|
| 457 |
+
fsdp_config = None,
|
| 458 |
+
deepspeed = None,
|
| 459 |
+
debug = '',
|
| 460 |
+
skip_memory_metrics = True,
|
| 461 |
+
do_train = False,
|
| 462 |
+
do_eval = False,
|
| 463 |
+
do_predict = False,
|
| 464 |
+
resume_from_checkpoint = None,
|
| 465 |
+
warmup_ratio = None,
|
| 466 |
+
logging_dir = None,
|
| 467 |
+
local_rank = -1,
|
| 468 |
+
reward_model_path = None,
|
| 469 |
+
judge = None,
|
| 470 |
+
max_new_tokens = 64,
|
| 471 |
+
max_length = 512,
|
| 472 |
+
temperature = 0.9,
|
| 473 |
+
top_p = 1.0,
|
| 474 |
+
top_k = None,
|
| 475 |
+
min_p = None,
|
| 476 |
+
repetition_penalty = 1.0,
|
| 477 |
+
generation_kwargs = {},
|
| 478 |
+
use_transformers_paged = False,
|
| 479 |
+
cache_implementation = None,
|
| 480 |
+
missing_eos_penalty = None,
|
| 481 |
+
loss_type = 'sigmoid',
|
| 482 |
+
disable_dropout = True,
|
| 483 |
+
use_vllm = False,
|
| 484 |
+
vllm_model_impl = 'vllm',
|
| 485 |
+
vllm_guided_decoding_regex = None,
|
| 486 |
+
vllm_gpu_memory_utilization = 0.55,
|
| 487 |
+
vllm_mode = 'colocate',
|
| 488 |
+
vllm_server_base_url = None,
|
| 489 |
+
vllm_server_host = '0.0.0.0',
|
| 490 |
+
vllm_server_port = 8000,
|
| 491 |
+
vllm_server_timeout = 240.0,
|
| 492 |
+
vllm_tensor_parallel_size = 1,
|
| 493 |
+
ds3_gather_for_generation = True,
|
| 494 |
+
model_init_kwargs = None,
|
| 495 |
+
reward_weights = None,
|
| 496 |
+
dataset_num_proc = None,
|
| 497 |
+
gpu_memory_utilization = None,
|
| 498 |
+
vllm_sampling_params = None,
|
| 499 |
+
unsloth_num_chunks = -1,
|
| 500 |
+
unsloth_logit_chunk_multiplier = None,
|
| 501 |
+
unsloth_grpo_mini_batch = None,
|
| 502 |
+
max_seq_length = None,
|
| 503 |
+
**kwargs,
|
| 504 |
+
):
|
| 505 |
+
if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 506 |
+
if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 507 |
+
if num_train_epochs is None:
|
| 508 |
+
num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
|
| 509 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 510 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 511 |
+
save_strategy = 'no'
|
| 512 |
+
import multiprocessing as _mp
|
| 513 |
+
if _mp.get_start_method() != 'fork':
|
| 514 |
+
dataset_num_proc = None
|
| 515 |
+
elif dataset_num_proc is None:
|
| 516 |
+
import psutil
|
| 517 |
+
dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
|
| 518 |
+
memory_gb_left = psutil.virtual_memory().available / (1024**3)
|
| 519 |
+
if memory_gb_left <= 2: dataset_num_proc = 1
|
| 520 |
+
else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
|
| 521 |
+
if temperature <= 0:
|
| 522 |
+
raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 523 |
+
elif temperature >= 10:
|
| 524 |
+
raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
super().__init__(
|
| 528 |
+
output_dir = output_dir,
|
| 529 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 530 |
+
num_train_epochs = num_train_epochs,
|
| 531 |
+
max_steps = max_steps,
|
| 532 |
+
learning_rate = learning_rate,
|
| 533 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 534 |
+
lr_scheduler_kwargs = lr_scheduler_kwargs,
|
| 535 |
+
warmup_steps = warmup_steps,
|
| 536 |
+
optim = optim,
|
| 537 |
+
optim_args = optim_args,
|
| 538 |
+
weight_decay = weight_decay,
|
| 539 |
+
adam_beta1 = adam_beta1,
|
| 540 |
+
adam_beta2 = adam_beta2,
|
| 541 |
+
adam_epsilon = adam_epsilon,
|
| 542 |
+
optim_target_modules = optim_target_modules,
|
| 543 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 544 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 545 |
+
max_grad_norm = max_grad_norm,
|
| 546 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 547 |
+
bf16 = bf16,
|
| 548 |
+
fp16 = fp16,
|
| 549 |
+
bf16_full_eval = bf16_full_eval,
|
| 550 |
+
fp16_full_eval = fp16_full_eval,
|
| 551 |
+
tf32 = tf32,
|
| 552 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 553 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 554 |
+
torch_compile = torch_compile,
|
| 555 |
+
torch_compile_backend = torch_compile_backend,
|
| 556 |
+
torch_compile_mode = torch_compile_mode,
|
| 557 |
+
use_liger_kernel = use_liger_kernel,
|
| 558 |
+
liger_kernel_config = liger_kernel_config,
|
| 559 |
+
use_cache = use_cache,
|
| 560 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 561 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 562 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 563 |
+
logging_strategy = logging_strategy,
|
| 564 |
+
logging_steps = logging_steps,
|
| 565 |
+
logging_first_step = logging_first_step,
|
| 566 |
+
log_on_each_node = log_on_each_node,
|
| 567 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 568 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 569 |
+
log_level = log_level,
|
| 570 |
+
log_level_replica = log_level_replica,
|
| 571 |
+
disable_tqdm = disable_tqdm,
|
| 572 |
+
report_to = report_to,
|
| 573 |
+
run_name = run_name,
|
| 574 |
+
project = project,
|
| 575 |
+
trackio_space_id = trackio_space_id,
|
| 576 |
+
eval_strategy = eval_strategy,
|
| 577 |
+
eval_steps = eval_steps,
|
| 578 |
+
eval_delay = eval_delay,
|
| 579 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 580 |
+
prediction_loss_only = prediction_loss_only,
|
| 581 |
+
eval_on_start = eval_on_start,
|
| 582 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 583 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 584 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 585 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 586 |
+
save_only_model = save_only_model,
|
| 587 |
+
save_strategy = save_strategy,
|
| 588 |
+
save_steps = save_steps,
|
| 589 |
+
save_on_each_node = save_on_each_node,
|
| 590 |
+
save_total_limit = save_total_limit,
|
| 591 |
+
enable_jit_checkpoint = enable_jit_checkpoint,
|
| 592 |
+
push_to_hub = push_to_hub,
|
| 593 |
+
hub_token = hub_token,
|
| 594 |
+
hub_private_repo = hub_private_repo,
|
| 595 |
+
hub_model_id = hub_model_id,
|
| 596 |
+
hub_strategy = hub_strategy,
|
| 597 |
+
hub_always_push = hub_always_push,
|
| 598 |
+
hub_revision = hub_revision,
|
| 599 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 600 |
+
metric_for_best_model = metric_for_best_model,
|
| 601 |
+
greater_is_better = greater_is_better,
|
| 602 |
+
ignore_data_skip = ignore_data_skip,
|
| 603 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 604 |
+
full_determinism = full_determinism,
|
| 605 |
+
seed = seed,
|
| 606 |
+
data_seed = data_seed,
|
| 607 |
+
use_cpu = use_cpu,
|
| 608 |
+
accelerator_config = accelerator_config,
|
| 609 |
+
parallelism_config = parallelism_config,
|
| 610 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 611 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 612 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 613 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 614 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 615 |
+
remove_unused_columns = remove_unused_columns,
|
| 616 |
+
label_names = label_names,
|
| 617 |
+
train_sampling_strategy = train_sampling_strategy,
|
| 618 |
+
length_column_name = length_column_name,
|
| 619 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 620 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 621 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 622 |
+
ddp_backend = ddp_backend,
|
| 623 |
+
ddp_timeout = ddp_timeout,
|
| 624 |
+
fsdp = fsdp,
|
| 625 |
+
fsdp_config = fsdp_config,
|
| 626 |
+
deepspeed = deepspeed,
|
| 627 |
+
debug = debug,
|
| 628 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 629 |
+
do_train = do_train,
|
| 630 |
+
do_eval = do_eval,
|
| 631 |
+
do_predict = do_predict,
|
| 632 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 633 |
+
warmup_ratio = warmup_ratio,
|
| 634 |
+
logging_dir = logging_dir,
|
| 635 |
+
local_rank = local_rank,
|
| 636 |
+
reward_model_path = reward_model_path,
|
| 637 |
+
judge = judge,
|
| 638 |
+
max_new_tokens = max_new_tokens,
|
| 639 |
+
max_length = max_length,
|
| 640 |
+
temperature = temperature,
|
| 641 |
+
top_p = top_p,
|
| 642 |
+
top_k = top_k,
|
| 643 |
+
min_p = min_p,
|
| 644 |
+
repetition_penalty = repetition_penalty,
|
| 645 |
+
generation_kwargs = generation_kwargs,
|
| 646 |
+
use_transformers_paged = use_transformers_paged,
|
| 647 |
+
cache_implementation = cache_implementation,
|
| 648 |
+
missing_eos_penalty = missing_eos_penalty,
|
| 649 |
+
loss_type = loss_type,
|
| 650 |
+
disable_dropout = disable_dropout,
|
| 651 |
+
use_vllm = use_vllm,
|
| 652 |
+
vllm_model_impl = vllm_model_impl,
|
| 653 |
+
vllm_guided_decoding_regex = vllm_guided_decoding_regex,
|
| 654 |
+
vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
|
| 655 |
+
vllm_mode = vllm_mode,
|
| 656 |
+
vllm_server_base_url = vllm_server_base_url,
|
| 657 |
+
vllm_server_host = vllm_server_host,
|
| 658 |
+
vllm_server_port = vllm_server_port,
|
| 659 |
+
vllm_server_timeout = vllm_server_timeout,
|
| 660 |
+
vllm_tensor_parallel_size = vllm_tensor_parallel_size,
|
| 661 |
+
ds3_gather_for_generation = ds3_gather_for_generation,
|
| 662 |
+
model_init_kwargs = model_init_kwargs,
|
| 663 |
+
reward_weights = reward_weights,
|
| 664 |
+
dataset_num_proc = dataset_num_proc,
|
| 665 |
+
gpu_memory_utilization = gpu_memory_utilization,**kwargs)
|
| 666 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 667 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 668 |
+
if unsloth_grpo_mini_batch is not None:
|
| 669 |
+
if self.generation_batch_size >= unsloth_grpo_mini_batch:
|
| 670 |
+
self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
|
| 671 |
+
else:
|
| 672 |
+
raise ValueError(
|
| 673 |
+
f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
|
| 674 |
+
f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
|
| 675 |
+
)
|
| 676 |
+
self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
|
| 677 |
+
self.max_seq_length = max_seq_length
|
| 678 |
+
|
| 679 |
+
pass
|
| 680 |
+
|
| 681 |
+
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
| 682 |
+
""""""
|
| 683 |
+
|
| 684 |
+
_tag_names = ["trl", "nash-md"]
|
| 685 |
+
_name = "Nash-MD"
|
| 686 |
+
_paper = {
|
| 687 |
+
"title": "Nash Learning from Human Feedback",
|
| 688 |
+
"id": "2312.00886",
|
| 689 |
+
# docstyle-ignore
|
| 690 |
+
"citation": textwrap.dedent("""\
|
| 691 |
+
@inproceedings{munos2024nash,
|
| 692 |
+
title = {{Nash Learning from Human Feedback}},
|
| 693 |
+
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
| 694 |
+
year = 2024,
|
| 695 |
+
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
| 696 |
+
publisher = {OpenReview.net},
|
| 697 |
+
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
| 698 |
+
}"""),
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        ref_model: Union[PreTrainedModel, nn.Module] = None,
        reward_funcs: Union[PreTrainedModel, nn.Module, None] = None,
        judge: Optional[BasePairwiseJudge] = None,
        args: Optional[NashMDConfig] = None,
        data_collator: Optional[Callable] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        # Deprecated parameters
        reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
    ) -> None:
        """Initialize the Nash-MD trainer on top of ``OnlineDPOTrainer``.

        Delegates all standard setup (model/ref-model wiring, datasets,
        collator, callbacks, optimizers) to the parent class, then replaces
        the parent's ``stats`` dictionary with Nash-MD-specific metric keys
        and normalizes ``reward_funcs`` to a single reward model.

        Raises:
            ValueError: if more than one reward function/model is supplied.
        """
        super().__init__(
            model=model,
            ref_model=ref_model,
            reward_funcs=reward_funcs,
            judge=judge,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            # The same processing class is reused to tokenize reward-model inputs.
            reward_processing_classes=processing_class,
            peft_config=peft_config,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            reward_model=reward_model,
        )

        # Snapshot the (possibly per-epoch list) mixture coefficient from the config;
        # the `mixture_coef` property below resolves it per epoch.
        self._mixture_coef = self.args.mixture_coef

        # Overwrite the stats dictionary to include NashMD specific statistics
        self.stats = {
            # Remove "non_score_reward", "rlhf_reward", "scores_margin"
            # Add "mixture_coef"
            "loss/kl": [],
            "objective/entropy": [],
            "loss/score": [],
            "rewards/probabilities": [],
            "rewards/accuracies": [],
            "rewards/margins": [],
            "logps/chosen": [],
            "logps/rejected": [],
            "val/model_contain_eos_token": [],
            "val/ref_contain_eos_token": [],
            "beta": [],
            "mixture_coef": [],
        }
        if self.reward_funcs is not None:
            # Nash-MD scores pairs with exactly one reward model; unwrap the
            # single-element list the parent class stores.
            if len(self.reward_funcs) != 1:
                raise ValueError("NashMDTrainer only supports one reward function/model.")
            self.reward_funcs = self.reward_funcs[0]
            self.stats["rewards/chosen"] = []
            self.stats["rewards/rejected"] = []
|
| 766 |
+
|
| 767 |
+
@property
|
| 768 |
+
def mixture_coef(self):
|
| 769 |
+
if isinstance(self._mixture_coef, list):
|
| 770 |
+
epoch = self.state.epoch
|
| 771 |
+
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
| 772 |
+
else:
|
| 773 |
+
return self._mixture_coef
|
| 774 |
+
|
| 775 |
+
    def _generate_completions(self, model, prompts):
        """Sample completions from the policy and from the geometric mixture
        of policy and reference model.

        Args:
            model: the (possibly wrapped) policy model being trained.
            prompts: dict with ``input_ids`` and ``attention_mask`` tensors.

        Returns:
            ``(model_output, mixture_output)`` token-id tensors; both still
            include the prompt prefix.
        """
        # Generate completions from the policy model.
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
            model_output = unwrapped_policy_for_gen_ctx.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=self.generation_config,
            )

        # Get the DDP/FSDP unwrapped version of the main model.
        # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
        policy_model_for_gmw = self.accelerator.unwrap_model(model)

        # Determine the correct reference model for GeometricMixtureWrapper.
        # This also needs to be DDP/FSDP unwrapped.
        ref_model_for_gmw: torch.nn.Module
        if self.ref_model is None:
            # No explicit ref_model is provided.
            # Use the base of the main `model` if it's a PEFT model.
            # policy_model_for_gmw is already DDP-unwrapped.
            if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
                ref_model_for_gmw = policy_model_for_gmw.get_base_model()
            else:
                # Not a PEFT model (or PEFT not available), or already a base model.
                # Use the DDP-unwrapped policy model itself as the reference.
                ref_model_for_gmw = policy_model_for_gmw
        else:
            # An explicit ref_model is provided. Unwrap it for DDP/FSDP.
            ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)

        # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
        with torch.no_grad():  # Ensure no_grad context for mixture model generation
            mixture_model = GeometricMixtureWrapper(
                model=policy_model_for_gmw,
                ref_model=ref_model_for_gmw,
                generation_config=self.generation_config,
                mixture_coef=self.mixture_coef,
                device=self.accelerator.device,
            )

            mixture_output = mixture_model.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=self.generation_config,
            )

        return model_output, mixture_output
|
| 822 |
+
|
| 823 |
+
def _process_completions(self, model_output, mixture_output, prompts):
|
| 824 |
+
context_length = prompts["input_ids"].shape[1]
|
| 825 |
+
|
| 826 |
+
# Process model completions
|
| 827 |
+
model_completion_ids = model_output[:, context_length:]
|
| 828 |
+
model_completion_ids, model_completion_mask = truncate_right(
|
| 829 |
+
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 830 |
+
)
|
| 831 |
+
model_data = {
|
| 832 |
+
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
| 833 |
+
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
| 834 |
+
"raw": prompts["raw"],
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
# Process reference model completions
|
| 838 |
+
mixture_completion_ids = mixture_output[:, context_length:]
|
| 839 |
+
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
| 840 |
+
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 841 |
+
)
|
| 842 |
+
mixture_data = {
|
| 843 |
+
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
| 844 |
+
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
| 845 |
+
"raw": prompts["raw"],
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
return model_data, mixture_data
|
| 849 |
+
|
| 850 |
+
def _compute_rewards(self, model_data, mixture_data, context_length):
|
| 851 |
+
with torch.no_grad():
|
| 852 |
+
_, model_scores, _ = get_reward(
|
| 853 |
+
self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 854 |
+
)
|
| 855 |
+
_, mixture_scores, _ = get_reward(
|
| 856 |
+
self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
# Apply EOS penalty if needed
|
| 860 |
+
if self.args.missing_eos_penalty is not None:
|
| 861 |
+
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 862 |
+
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 863 |
+
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
| 864 |
+
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
| 865 |
+
|
| 866 |
+
return model_scores, mixture_scores
|
| 867 |
+
|
| 868 |
+
    def _compute_judge(self, model_data, mixture_data, context_length):
        """Ask the pairwise judge for the probability that the policy
        completion beats the mixture completion, per prompt.

        Returns a float tensor of win probabilities on the same device as
        ``model_data["input_ids"]``.
        """
        prompts = model_data["raw"]
        model_data_completions = self.processing_class.batch_decode(
            model_data["input_ids"][:, context_length:], skip_special_tokens=True
        )
        model_data_completions = [completion.strip() for completion in model_data_completions]

        mixture_data_completions = self.processing_class.batch_decode(
            mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
        )
        mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
        if is_conversational({"prompt": prompts[0]}):
            # Conversational data: re-render prompts and completions through a
            # minimal chat template so the judge receives plain strings.
            model_data_completions = [
                [{"role": "assistant", "content": completion}] for completion in model_data_completions
            ]
            environment = jinja2.Environment()
            template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
            prompts = [template.render(messages=message) for message in prompts]
            model_data_completions = [template.render(messages=completion) for completion in model_data_completions]

            mixture_data_completions = [
                [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
            ]
            mixture_data_completions = [
                template.render(messages=completion) for completion in mixture_data_completions
            ]

        probability = self.judge.judge(
            prompts,
            list(zip(model_data_completions, mixture_data_completions)),
            return_scores=True,
        )
        return torch.tensor(probability, device=model_data["input_ids"].device)
|
| 901 |
+
|
| 902 |
+
    def _compute_logprobs(self, model, model_data, context_length):
        """Compute per-token log-probabilities of the sampled completion
        under both the policy and the reference model.

        Padding positions are zeroed so later ``.sum(1)`` reductions ignore
        them. Returns ``(model_logprobs, ref_logprobs)``.
        """
        def compute_logprobs_for_data(m, data):
            # Logits at position t predict token t+1, so shift the slice by
            # one to align each logit with the completion token it generated.
            output = m(data["input_ids"], attention_mask=data["attention_mask"])
            logits = output.logits[:, context_length - 1 : -1]
            token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
            return token_logprobs

        # Compute logprobs for model completions under the model
        model_logprobs_model_data = compute_logprobs_for_data(model, model_data)

        # Compute logprobs of model completions under the reference model
        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: temporarily disable the PEFT
                # adapters so the base weights act as the reference.
                with model.disable_adapter():
                    ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
            else:
                ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)

        # Mask padding tokens
        model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
        model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
        ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)

        return (model_logprobs_model_data, ref_logprobs_model_data)
|
| 926 |
+
|
| 927 |
+
    def _compute_losses(
        self,
        model_logprobs_model_data,
        ref_logprobs_model_data,
        probability,
    ):
        """Combine the preference signal with a KL regularizer into the
        Nash-MD loss.

        Returns ``(loss, score, kl_div)``: a scalar loss plus per-sample
        score and KL tensors that are used only for logging.
        """
        # reinforce score where 0.5 is a control variate
        score = (probability - 0.5) * model_logprobs_model_data.sum(1)

        # kl divergence via reinforce
        with torch.no_grad():
            log_ratio = model_logprobs_model_data - ref_logprobs_model_data
            kl_div_log = log_ratio.sum(1)
        # NOTE(review): placed outside the no_grad block so the REINFORCE-style
        # KL term gets gradients through the policy logprobs (log_ratio itself
        # is detached). This matches the upstream TRL implementation; confirm
        # against the original file since indentation was lost in this view.
        kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)

        # final loss
        loss = self.beta * kl_div_loss - score

        return loss.mean(), score, kl_div_log
|
| 946 |
+
|
| 947 |
+
def _log_statistics(
|
| 948 |
+
self,
|
| 949 |
+
model_data,
|
| 950 |
+
mixture_data,
|
| 951 |
+
model_logprobs_model_data,
|
| 952 |
+
ref_logprobs_model_data,
|
| 953 |
+
probability,
|
| 954 |
+
score,
|
| 955 |
+
kl_div,
|
| 956 |
+
context_length,
|
| 957 |
+
model_scores=None,
|
| 958 |
+
mixture_scores=None,
|
| 959 |
+
):
|
| 960 |
+
# Helper function to gather and compute mean
|
| 961 |
+
def gather_mean(tensor):
|
| 962 |
+
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
| 963 |
+
|
| 964 |
+
# Log score
|
| 965 |
+
self.stats["loss/score"].append(gather_mean(score))
|
| 966 |
+
# Log KL divergence
|
| 967 |
+
self.stats["loss/kl"].append(gather_mean(kl_div))
|
| 968 |
+
|
| 969 |
+
# Log logprobs
|
| 970 |
+
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
| 971 |
+
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
| 972 |
+
|
| 973 |
+
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
| 974 |
+
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
| 975 |
+
|
| 976 |
+
# Log rewards
|
| 977 |
+
if self.reward_funcs is not None:
|
| 978 |
+
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
| 979 |
+
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
| 980 |
+
|
| 981 |
+
# Log probabilities
|
| 982 |
+
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
| 983 |
+
|
| 984 |
+
# Calculate entropy for model data
|
| 985 |
+
entropy_model_data = -model_logprobs_model_data.sum(1)
|
| 986 |
+
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
| 987 |
+
|
| 988 |
+
# Calculate margins
|
| 989 |
+
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
| 990 |
+
self.stats["rewards/margins"].append(gather_mean(margin))
|
| 991 |
+
|
| 992 |
+
# Calculate accuracy
|
| 993 |
+
accuracy = (margin > 0).float()
|
| 994 |
+
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
| 995 |
+
|
| 996 |
+
# Log EOS token statistics
|
| 997 |
+
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 998 |
+
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 999 |
+
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
| 1000 |
+
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
| 1001 |
+
|
| 1002 |
+
# Log beta and mixture coef
|
| 1003 |
+
self.stats["beta"].append(self.beta)
|
| 1004 |
+
self.stats["mixture_coef"].append(self.mixture_coef)
|
| 1005 |
+
|
| 1006 |
+
    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """Run one Nash-MD optimization step.

        Tokenizes the batch, samples completions from the policy and the
        policy/reference mixture, computes the preference probability (via
        reward model or judge), forms the loss, logs statistics and performs
        the backward pass.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        model.train()

        # Apply chat template and tokenize the input
        batch_size = len(next(iter(inputs.values())))
        prompts = inputs["prompt"]
        # Re-shape the columnar batch into one dict per example before templating.
        inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
        inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
        inputs = self.data_collator(inputs)

        # need the prompt_ only
        inputs = self._prepare_inputs(inputs)
        context_length = inputs["prompt_input_ids"].shape[1]
        prompts = {
            "input_ids": inputs["prompt_input_ids"],
            "attention_mask": inputs["prompt_attention_mask"],
            "raw": prompts,
        }
        del inputs  # free the full collated batch; only the prompt tensors are needed

        # Sample completions from both the model and the reference model
        model_output, mixture_output = self._generate_completions(model, prompts)

        # Process model completions
        model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)

        # Compute rewards
        if self.reward_funcs is not None:
            model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
            # probability of the model data vs the mixture data
            probability = F.sigmoid(model_scores - mixture_scores)
        else:
            # No reward model: fall back to the pairwise judge.
            model_scores, mixture_scores = None, None
            probability = self._compute_judge(model_data, mixture_data, context_length)

        # Compute logprobs
        model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)

        # Compute loss
        loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)

        # Log everything
        self._log_statistics(
            model_data,
            mixture_data,
            model_logprobs_model_data.detach(),
            ref_logprobs_model_data,
            probability,
            score.detach(),
            kl_div.detach(),
            context_length,
            model_scores,
            mixture_scores,
        )

        # Periodically release cached GPU memory if configured.
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            empty_cache()

        kwargs = {}
        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        self.accelerator.backward(loss, **kwargs)

        return loss.detach() / self.args.gradient_accumulation_steps
|
| 1081 |
+
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
| 1082 |
+
"""
|
| 1083 |
+
|
| 1084 |
+
Trainer for the Nash-MD method.
|
| 1085 |
+
|
| 1086 |
+
It is implemented as a subclass of [`OnlineDPOTrainer`].
|
| 1087 |
+
|
| 1088 |
+
Args:
|
| 1089 |
+
model ([`~transformers.PreTrainedModel`]):
|
| 1090 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
| 1091 |
+
ref_model ([`PreTrainedModelWrapper`]):
|
| 1092 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
|
| 1093 |
+
and loss. If no reference model is provided, the trainer will create a reference model with the same
|
| 1094 |
+
architecture as the model to be optimized.
|
| 1095 |
+
reward_funcs ([`~transformers.PreTrainedModel`]):
|
| 1096 |
+
The reward model to score completions with, preferably an
|
| 1097 |
+
[`~transformers.AutoModelForSequenceClassification`].
|
| 1098 |
+
judge ([`BasePairwiseJudge`]):
|
| 1099 |
+
The judge to use for pairwise comparison of model completions.
|
| 1100 |
+
args ([`NashMDConfig`]):
|
| 1101 |
+
The NashMD config arguments to use for training.
|
| 1102 |
+
data_collator ([`~transformers.DataCollator`]):
|
| 1103 |
+
The data collator to use for training. If None is specified, the default data collator
|
| 1104 |
+
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
|
| 1105 |
+
sequences in the batch, given a dataset of paired sequences.
|
| 1106 |
+
train_dataset ([`~datasets.Dataset`]):
|
| 1107 |
+
The dataset to use for training.
|
| 1108 |
+
eval_dataset ([`~datasets.Dataset`]):
|
| 1109 |
+
The dataset to use for evaluation.
|
| 1110 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
| 1111 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1112 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1113 |
+
reuse the fine-tuned model.
|
| 1114 |
+
peft_config (`dict`):
|
| 1115 |
+
The peft config to use for training.
|
| 1116 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1117 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
|
| 1118 |
+
metric values.
|
| 1119 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1120 |
+
The callbacks to use for training.
|
| 1121 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1122 |
+
The optimizer and scheduler to use for training.
|
| 1123 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1124 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1125 |
+
|
| 1126 |
+
reward_model:
|
| 1127 |
+
|
| 1128 |
+
<Deprecated version="0.22.0">
|
| 1129 |
+
|
| 1130 |
+
This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
|
| 1131 |
+
|
| 1132 |
+
</Deprecated>
|
| 1133 |
+
|
| 1134 |
+
"""
|
| 1135 |
+
def __init__(
|
| 1136 |
+
self,
|
| 1137 |
+
model = None,
|
| 1138 |
+
ref_model = None,
|
| 1139 |
+
reward_funcs = None,
|
| 1140 |
+
judge = None,
|
| 1141 |
+
args = None,
|
| 1142 |
+
data_collator = None,
|
| 1143 |
+
train_dataset = None,
|
| 1144 |
+
eval_dataset = None,
|
| 1145 |
+
processing_class = None,
|
| 1146 |
+
peft_config = None,
|
| 1147 |
+
compute_metrics = None,
|
| 1148 |
+
callbacks = None,
|
| 1149 |
+
preprocess_logits_for_metrics = None,
|
| 1150 |
+
reward_model = None,
|
| 1151 |
+
**kwargs
|
| 1152 |
+
):
|
| 1153 |
+
if args is None: args = UnslothNashMDConfig()
|
| 1154 |
+
use_bf16 = getattr(args, 'bf16', False)
|
| 1155 |
+
if type(use_bf16) is not bool: use_bf16 = False
|
| 1156 |
+
use_fp16 = getattr(args, 'fp16', False)
|
| 1157 |
+
if type(use_fp16) is not bool: use_fp16 = False
|
| 1158 |
+
force_float32 = False
|
| 1159 |
+
full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
|
| 1160 |
+
if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
|
| 1161 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1162 |
+
force_float32 = True
|
| 1163 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1164 |
+
dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
|
| 1165 |
+
if dtype is None: dtype = model.get_input_embeddings().weight.dtype
|
| 1166 |
+
from unsloth_zoo.utils import _get_dtype
|
| 1167 |
+
dtype = _get_dtype(dtype)
|
| 1168 |
+
float16 = dtype == torch.float16
|
| 1169 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1170 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1171 |
+
if force_float32:
|
| 1172 |
+
# Forced float32 training
|
| 1173 |
+
args.fp16 = False
|
| 1174 |
+
args.bf16 = False
|
| 1175 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1176 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
|
| 1177 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1178 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1179 |
+
# Mixed precision training
|
| 1180 |
+
args.fp16 = float16
|
| 1181 |
+
args.bf16 = not float16
|
| 1182 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1183 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
|
| 1184 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1185 |
+
elif mixed_precision_dtype == 'bfloat16':
|
| 1186 |
+
# Both False since bfloat16 full finetuning doesn't do any autocasting.
|
| 1187 |
+
args.fp16 = False
|
| 1188 |
+
args.bf16 = False
|
| 1189 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1190 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
|
| 1191 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1192 |
+
|
| 1193 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1194 |
+
args.eval_strategy = 'steps'
|
| 1195 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1196 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1197 |
+
if ga_steps is not None and ga_steps > 1:
|
| 1198 |
+
from transformers import __version__ as transformers_version
|
| 1199 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
| 1200 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1201 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1202 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1203 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1204 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1205 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1206 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1207 |
+
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1208 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1209 |
+
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1210 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1211 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1212 |
+
if force_float32:
|
| 1213 |
+
args.bf16_full_eval = False
|
| 1214 |
+
args.fp16_full_eval = False
|
| 1215 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1216 |
+
args.bf16_full_eval = True
|
| 1217 |
+
args.fp16_full_eval = False
|
| 1218 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1219 |
+
args.bf16_full_eval = args.bf16
|
| 1220 |
+
args.fp16_full_eval = args.fp16
|
| 1221 |
+
_output_logits = False
|
| 1222 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1223 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1224 |
+
if _output_logits:
|
| 1225 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1226 |
+
if model is not None:
|
| 1227 |
+
_warnings_issued = getattr(model, 'warnings_issued', None)
|
| 1228 |
+
if _warnings_issued is None:
|
| 1229 |
+
model.warnings_issued = {}
|
| 1230 |
+
elif not isinstance(_warnings_issued, dict):
|
| 1231 |
+
try:
|
| 1232 |
+
model.warnings_issued = dict(_warnings_issued)
|
| 1233 |
+
except Exception:
|
| 1234 |
+
model.warnings_issued = {}
|
| 1235 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1236 |
+
pass
|
| 1237 |
+
else:
|
| 1238 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1239 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1240 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1241 |
+
max_seq_length = model.max_seq_length
|
| 1242 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1243 |
+
elif args_max_seq_length is not None and model_max_seq_length is not None:
|
| 1244 |
+
if args_max_seq_length > model_max_seq_length:
|
| 1245 |
+
print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
|
| 1246 |
+
'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
|
| 1247 |
+
args.max_seq_length = model_max_seq_length
|
| 1248 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1249 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1250 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1251 |
+
if 'processing_class' in locals():
|
| 1252 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1253 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1254 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1255 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1256 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1257 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1258 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1259 |
+
__tokenizer,
|
| 1260 |
+
mlm = False,
|
| 1261 |
+
mlm_probability = 0.0,
|
| 1262 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1263 |
+
)
|
| 1264 |
+
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1265 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1266 |
+
__tokenizer,
|
| 1267 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1268 |
+
)
|
| 1269 |
+
else:
|
| 1270 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1271 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1272 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1273 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1274 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1275 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1276 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1277 |
+
__tokenizer.tokenizer,
|
| 1278 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1279 |
+
)
|
| 1280 |
+
else:
|
| 1281 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1282 |
+
__tokenizer.tokenizer,
|
| 1283 |
+
mlm = False,
|
| 1284 |
+
mlm_probability = 0.0,
|
| 1285 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1286 |
+
)
|
| 1287 |
+
other_metrics = []
|
| 1288 |
+
|
| 1289 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1290 |
+
PatchRLStatistics('nash_md_trainer', other_metrics)
|
| 1291 |
+
|
| 1292 |
+
# [TODO] Fix up DataParallel multiplying batch sizes
|
| 1293 |
+
# [TODO] DDP works, but DP seems to not work? [TODO]
|
| 1294 |
+
if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
|
| 1295 |
+
if getattr(args, "_n_gpu", 1) != 1:
|
| 1296 |
+
args._n_gpu = 1
|
| 1297 |
+
if "model" in locals() and hasattr(model, "for_training"):
|
| 1298 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1299 |
+
super().__init__(
|
| 1300 |
+
model = model,
|
| 1301 |
+
ref_model = ref_model,
|
| 1302 |
+
reward_funcs = reward_funcs,
|
| 1303 |
+
judge = judge,
|
| 1304 |
+
args = args,
|
| 1305 |
+
data_collator = data_collator,
|
| 1306 |
+
train_dataset = train_dataset,
|
| 1307 |
+
eval_dataset = eval_dataset,
|
| 1308 |
+
processing_class = processing_class,
|
| 1309 |
+
peft_config = peft_config,
|
| 1310 |
+
compute_metrics = compute_metrics,
|
| 1311 |
+
callbacks = callbacks,
|
| 1312 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1313 |
+
reward_model = reward_model,**kwargs)
|
| 1314 |
+
if "model" in locals() and hasattr(model, "for_inference"):
|
| 1315 |
+
model.for_inference()
|
| 1316 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1317 |
+
self.neftune_hook_handle.remove()
|
| 1318 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1319 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1320 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1321 |
+
pass
|
| 1322 |
+
if hasattr(self, 'accelerator'):
|
| 1323 |
+
scaler = self.accelerator.scaler
|
| 1324 |
+
current_model = model
|
| 1325 |
+
while hasattr(current_model, 'model'):
|
| 1326 |
+
current_model.accelerator_scaler = scaler
|
| 1327 |
+
current_model = current_model.model
|
| 1328 |
+
current_model.accelerator_scaler = scaler
|
| 1329 |
+
pass
|
| 1330 |
+
if hasattr(self, 'train'):
|
| 1331 |
+
self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
|
| 1332 |
+
pass
|
| 1333 |
+
if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
|
| 1334 |
+
_vllm_tok = self.llm.get_tokenizer()
|
| 1335 |
+
_pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
|
| 1336 |
+
if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
|
| 1337 |
+
_vllm_tok.chat_template = _pc.chat_template
|
| 1338 |
+
pass
|
| 1339 |
+
|
| 1340 |
+
pass
|
unsloth_compiled_cache/UnslothORPOTrainer.py
CHANGED
|
@@ -28,7 +28,7 @@ import torch.nn as nn
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
-
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
|
|
|
| 28 |
from torch.nn import functional as F
|
| 29 |
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, wandb, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, wandb, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
|
| 32 |
|
| 33 |
|
| 34 |
import os
|
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/UnslothRLOOTrainer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/UnslothXPOTrainer.py
ADDED
|
@@ -0,0 +1,1385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2026.3.2
|
| 3 |
+
2026.3.4
|
| 4 |
+
5.3.0
|
| 5 |
+
0.24.0
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth auto generated code
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from torch.nn import functional as F
|
| 29 |
+
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
import math
|
| 36 |
+
import logging
|
| 37 |
+
from typing import *
|
| 38 |
+
from dataclasses import dataclass, field
|
| 39 |
+
from packaging.version import Version
|
| 40 |
+
import torch
|
| 41 |
+
import numpy as np
|
| 42 |
+
from contextlib import nullcontext
|
| 43 |
+
from torch.nn import functional as F
|
| 44 |
+
import inspect
|
| 45 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 46 |
+
from transformers.training_args import ParallelMode
|
| 47 |
+
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
|
| 48 |
+
|
| 49 |
+
# Wrap trainer with padding to right and enable training mode
|
| 50 |
+
# Also patches W&B since multiple runs must use wandb.finish()
|
| 51 |
+
import functools
|
| 52 |
+
from types import MethodType
|
| 53 |
+
try:
|
| 54 |
+
from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
|
| 55 |
+
except:
|
| 56 |
+
def reset_unsloth_gradient_checkpointing_buffers(): pass
|
| 57 |
+
def prepare_for_training_mode(f):
|
| 58 |
+
@functools.wraps(f)
|
| 59 |
+
def wrapper(self, *args, **kwargs):
|
| 60 |
+
# Enable training mode
|
| 61 |
+
_was_training = None
|
| 62 |
+
# Get gradient checkpointing setting from training arguments
|
| 63 |
+
use_gc = getattr(self.args, 'gradient_checkpointing', True)
|
| 64 |
+
if hasattr(self, 'model') and hasattr(self.model, "training"):
|
| 65 |
+
_was_training = self.model.training
|
| 66 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_training"):
|
| 67 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 68 |
+
output = f(self, *args, **kwargs)
|
| 69 |
+
# Restore previous mode when possible
|
| 70 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
|
| 71 |
+
if _was_training is False:
|
| 72 |
+
self.model.for_inference()
|
| 73 |
+
elif _was_training is True and hasattr(self.model, "for_training"):
|
| 74 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 75 |
+
# Reset gradient checkpointing buffers to free memory while staying ready for next run
|
| 76 |
+
try:
|
| 77 |
+
reset_unsloth_gradient_checkpointing_buffers()
|
| 78 |
+
except:
|
| 79 |
+
pass
|
| 80 |
+
# Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
|
| 81 |
+
try:
|
| 82 |
+
import wandb
|
| 83 |
+
wandb.finish()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
return output
|
| 87 |
+
return wrapper
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
torch_compile_options = {
|
| 91 |
+
"epilogue_fusion" : True,
|
| 92 |
+
"max_autotune" : False,
|
| 93 |
+
"shape_padding" : True,
|
| 94 |
+
"trace.enabled" : False,
|
| 95 |
+
"triton.cudagraphs" : False,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 99 |
+
def chunked_hidden_states_selective_log_softmax(
|
| 100 |
+
hidden_states: torch.Tensor,
|
| 101 |
+
lm_head: torch.Tensor,
|
| 102 |
+
index: torch.Tensor,
|
| 103 |
+
chunks: int = 4,
|
| 104 |
+
logit_scale_multiply: float = 0.0,
|
| 105 |
+
logit_scale_divide: float = 0.0,
|
| 106 |
+
logit_softcapping: float = 0.0,
|
| 107 |
+
temperature: float = 1.0,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
# All Unsloth Zoo code licensed under AGPL3
|
| 110 |
+
flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
|
| 111 |
+
flat_index = index.reshape(-1)
|
| 112 |
+
|
| 113 |
+
chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
|
| 114 |
+
chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
|
| 115 |
+
|
| 116 |
+
all_per_token_logps = []
|
| 117 |
+
|
| 118 |
+
for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
|
| 119 |
+
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
|
| 120 |
+
|
| 121 |
+
if logit_scale_multiply != 0.0:
|
| 122 |
+
chunk_logits = chunk_logits * logit_scale_multiply
|
| 123 |
+
if logit_scale_divide != 0.0:
|
| 124 |
+
chunk_logits = chunk_logits / logit_scale_divide
|
| 125 |
+
if logit_softcapping != 0.0:
|
| 126 |
+
chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
|
| 127 |
+
|
| 128 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 129 |
+
|
| 130 |
+
if temperature != 1.0:
|
| 131 |
+
chunk_logits = chunk_logits / temperature
|
| 132 |
+
|
| 133 |
+
selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 134 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
|
| 135 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 136 |
+
all_per_token_logps.append(per_token_logps)
|
| 137 |
+
|
| 138 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 139 |
+
|
| 140 |
+
all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
|
| 141 |
+
return all_per_token_logps
|
| 142 |
+
|
| 143 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 144 |
+
def chunked_selective_log_softmax(logits, index):
|
| 145 |
+
# Split into 4 chunks only
|
| 146 |
+
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 147 |
+
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 148 |
+
all_per_token_logps = []
|
| 149 |
+
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 150 |
+
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 151 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 152 |
+
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 153 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 154 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 155 |
+
all_per_token_logps.append(per_token_logps)
|
| 156 |
+
pass
|
| 157 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 158 |
+
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 159 |
+
return all_per_token_logps
|
| 160 |
+
|
| 161 |
+
def calculate_pad_tokens_in_prompt(
|
| 162 |
+
input_ids: torch.Tensor,
|
| 163 |
+
logits_to_keep: int,
|
| 164 |
+
pad_token_id: int
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
|
| 168 |
+
"""
|
| 169 |
+
if logits_to_keep >= input_ids.shape[1]:
|
| 170 |
+
raise ValueError("logits_to_keep must be smaller than the sequence length.")
|
| 171 |
+
|
| 172 |
+
prompt_section = input_ids[:, :-logits_to_keep]
|
| 173 |
+
|
| 174 |
+
padding_mask = (prompt_section == pad_token_id)
|
| 175 |
+
|
| 176 |
+
pad_token_counts = padding_mask.sum(dim=1)
|
| 177 |
+
|
| 178 |
+
return pad_token_counts
|
| 179 |
+
|
| 180 |
+
def create_completion_attention_mask(
|
| 181 |
+
completion_input_ids: torch.Tensor,
|
| 182 |
+
left_pad_tokens_per_prompt: torch.Tensor,
|
| 183 |
+
max_left_pad: int,
|
| 184 |
+
pad_token_id: int
|
| 185 |
+
) -> torch.Tensor:
|
| 186 |
+
"""
|
| 187 |
+
Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
|
| 188 |
+
|
| 189 |
+
Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
|
| 190 |
+
and pad are pad tokens, this function would make a completion mask that would 0 out the pad
|
| 191 |
+
and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
|
| 192 |
+
"""
|
| 193 |
+
batch_size, completion_len = completion_input_ids.shape
|
| 194 |
+
device = completion_input_ids.device
|
| 195 |
+
|
| 196 |
+
num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
|
| 197 |
+
|
| 198 |
+
indices = torch.arange(completion_len, device=device).unsqueeze(0)
|
| 199 |
+
shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
|
| 200 |
+
|
| 201 |
+
non_padding_mask = (completion_input_ids != pad_token_id)
|
| 202 |
+
|
| 203 |
+
final_mask = shift_mask & non_padding_mask
|
| 204 |
+
|
| 205 |
+
return final_mask
|
| 206 |
+
|
| 207 |
+
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
|
| 208 |
+
"""
|
| 209 |
+
Moves all padding tokens in each sequence of a batch to the right.
|
| 210 |
+
"""
|
| 211 |
+
mask = (tensor != pad_id)
|
| 212 |
+
# Must do stable=True since binary mark is unordered
|
| 213 |
+
sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
|
| 214 |
+
packed_tensor = torch.gather(tensor, 1, sorted_indices)
|
| 215 |
+
return packed_tensor
|
| 216 |
+
|
| 217 |
+
def align_logprobs_with_mask(
|
| 218 |
+
logprob_tensor: torch.Tensor,
|
| 219 |
+
attention_mask: torch.Tensor,
|
| 220 |
+
pad_value: float = 0.0
|
| 221 |
+
) -> torch.Tensor:
|
| 222 |
+
"""
|
| 223 |
+
Aligns a log probability tensor with a given attention mask.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
device = logprob_tensor.device
|
| 227 |
+
batch_size, logprob_seq_len = logprob_tensor.shape
|
| 228 |
+
mask_seq_len = attention_mask.shape[1]
|
| 229 |
+
|
| 230 |
+
padded_logprobs = torch.full(
|
| 231 |
+
attention_mask.shape,
|
| 232 |
+
fill_value=pad_value,
|
| 233 |
+
dtype=logprob_tensor.dtype,
|
| 234 |
+
device=device
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
left_pad_counts = torch.argmax(attention_mask, dim=1)
|
| 238 |
+
|
| 239 |
+
cols = torch.arange(logprob_seq_len, device=device)
|
| 240 |
+
dest_indices = left_pad_counts.unsqueeze(1) + cols
|
| 241 |
+
|
| 242 |
+
# Create destination row indices
|
| 243 |
+
# Shape: [batch_size, logprob_seq_len]
|
| 244 |
+
row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
|
| 245 |
+
|
| 246 |
+
# --- 4. Filter out-of-bounds indices and perform assignment ---
|
| 247 |
+
# Create a mask to identify only the indices that are within the bounds
|
| 248 |
+
# of the target tensor's sequence length.
|
| 249 |
+
valid_mask = dest_indices < mask_seq_len
|
| 250 |
+
|
| 251 |
+
# Use this mask to select only the valid row indices, column indices,
|
| 252 |
+
# and the corresponding values from the logprob tensor.
|
| 253 |
+
# This flattens the selected elements into 1D tensors.
|
| 254 |
+
valid_rows = row_indices[valid_mask]
|
| 255 |
+
valid_cols = dest_indices[valid_mask]
|
| 256 |
+
valid_vals = logprob_tensor[valid_mask]
|
| 257 |
+
|
| 258 |
+
# Place the valid values into their correct positions in the padded tensor
|
| 259 |
+
# using a single, efficient advanced indexing operation.
|
| 260 |
+
padded_logprobs[valid_rows, valid_cols] = valid_vals
|
| 261 |
+
|
| 262 |
+
return padded_logprobs
|
| 263 |
+
|
| 264 |
+
def autotune_batch_and_chunks(
|
| 265 |
+
total_input_rows,
|
| 266 |
+
seq_len,
|
| 267 |
+
hidden_size,
|
| 268 |
+
vocab_size,
|
| 269 |
+
dtype_bytes=16,
|
| 270 |
+
multiplier=None
|
| 271 |
+
):
|
| 272 |
+
if multiplier is None:
|
| 273 |
+
final_m = max(4, seq_len // 4096)
|
| 274 |
+
else:
|
| 275 |
+
final_m = multiplier
|
| 276 |
+
|
| 277 |
+
if torch.cuda.is_available():
|
| 278 |
+
free_bytes, _ = torch.cuda.mem_get_info()
|
| 279 |
+
limit_gb = (free_bytes / (1024**3))*.80
|
| 280 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 281 |
+
# For XPU: estimate free memory from total - reserved
|
| 282 |
+
total_mem = torch.xpu.get_device_properties(0).total_memory
|
| 283 |
+
reserved_mem = torch.xpu.memory_reserved()
|
| 284 |
+
free_bytes = total_mem - reserved_mem
|
| 285 |
+
limit_gb = (free_bytes / (1024**3)) * 0.80
|
| 286 |
+
else:
|
| 287 |
+
# Fallback: assume 8GB available
|
| 288 |
+
limit_gb = 8.0
|
| 289 |
+
|
| 290 |
+
bytes_to_gb = 1024**3
|
| 291 |
+
|
| 292 |
+
b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
|
| 295 |
+
|
| 296 |
+
base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
|
| 297 |
+
logits_gb = base_logits / final_m
|
| 298 |
+
|
| 299 |
+
total_mem_gb = hidden_gb + logits_gb
|
| 300 |
+
|
| 301 |
+
valid_mask = total_mem_gb <= limit_gb
|
| 302 |
+
valid_indices = torch.nonzero(valid_mask, as_tuple=False)
|
| 303 |
+
|
| 304 |
+
if valid_indices.shape[0] == 0:
|
| 305 |
+
#This means your GPU will OOM
|
| 306 |
+
return 4, final_m
|
| 307 |
+
|
| 308 |
+
best_idx = valid_indices[0].item()
|
| 309 |
+
final_b = int(b_vals[best_idx].item())
|
| 310 |
+
|
| 311 |
+
return final_b, final_m
|
| 312 |
+
|
| 313 |
+
def sanitize_logprob(logprob):
|
| 314 |
+
"""Local port of trl.scripts.vllm_serve.sanitize_logprob.
|
| 315 |
+
Filters NaN logprobs from vLLM outputs."""
|
| 316 |
+
value = logprob.logprob
|
| 317 |
+
if math.isnan(value):
|
| 318 |
+
logging.getLogger(__name__).warning(
|
| 319 |
+
f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
|
| 320 |
+
)
|
| 321 |
+
return None
|
| 322 |
+
return value
|
| 323 |
+
@dataclass
class UnslothXPOConfig(XPOConfig):
    """

    Configuration class for the [`XPOTrainer`].

    Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:

    Parameters:
        alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
            Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
            and the last alpha is used for the rest of the epochs.

    """
    # Optional vLLM SamplingParams object forwarded to the vLLM engine.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    # Chunking knob to trade speed for memory during loss computation.
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    unsloth_logit_chunk_multiplier : Optional[int] = field(
        default = None,
        metadata = {'help': 'Multiplier for chunked logit computations.'},
    )
    unsloth_grpo_mini_batch : Optional[int] = field(
        default = None,
        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        per_device_train_batch_size = 4,
        num_train_epochs = 3.0,
        max_steps = -1,
        learning_rate = 5e-05,
        lr_scheduler_type = 'linear',
        lr_scheduler_kwargs = None,
        warmup_steps = 0.1,
        optim = 'adamw_8bit',
        optim_args = None,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        optim_target_modules = None,
        gradient_accumulation_steps = 2,
        average_tokens_across_devices = True,
        max_grad_norm = 1.0,
        label_smoothing_factor = 0.0,
        bf16 = False,
        fp16 = False,
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = None,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        use_liger_kernel = False,
        liger_kernel_config = None,
        use_cache = False,
        neftune_noise_alpha = None,
        torch_empty_cache_steps = 250,
        auto_find_batch_size = False,
        logging_strategy = 'steps',
        logging_steps = 1,
        logging_first_step = False,
        log_on_each_node = True,
        logging_nan_inf_filter = False,
        include_num_input_tokens_seen = False,
        log_level = 'passive',
        log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        reward_model_path = None,
        judge = None,
        max_new_tokens = 64,
        max_length = 512,
        temperature = 0.9,
        top_p = 1.0,
        top_k = None,
        min_p = None,
        repetition_penalty = 1.0,
        # BUGFIX: was a mutable `{}` default, shared across every config instance.
        # `None` is normalized to a fresh dict below, so behavior is unchanged.
        generation_kwargs = None,
        use_transformers_paged = False,
        cache_implementation = None,
        missing_eos_penalty = None,
        loss_type = 'sigmoid',
        disable_dropout = True,
        use_vllm = False,
        vllm_model_impl = 'vllm',
        vllm_guided_decoding_regex = None,
        vllm_gpu_memory_utilization = 0.55,
        vllm_mode = 'colocate',
        vllm_server_base_url = None,
        vllm_server_host = '0.0.0.0',
        vllm_server_port = 8000,
        vllm_server_timeout = 240.0,
        vllm_tensor_parallel_size = 1,
        ds3_gather_for_generation = True,
        model_init_kwargs = None,
        reward_weights = None,
        dataset_num_proc = None,
        gpu_memory_utilization = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
        """Build an XPOConfig with Unsloth-friendly defaults and extra knobs.

        All parameters are forwarded to `XPOConfig.__init__`; the Unsloth-specific
        ones (`vllm_sampling_params`, `unsloth_num_chunks`,
        `unsloth_logit_chunk_multiplier`, `unsloth_grpo_mini_batch`,
        `max_seq_length`) are stored on the instance afterwards.

        Raises:
            ValueError: if `temperature` is not in (0, 10), or if
                `unsloth_grpo_mini_batch` exceeds the generation batch size.
        """
        # Warn on learning rates that would make gradient updates useless.
        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if num_train_epochs is None:
            num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
        # Avoid writing checkpoints when the user left every save option at its default.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Normalize the (previously mutable) default to a per-instance dict.
        if generation_kwargs is None:
            generation_kwargs = {}
        import multiprocessing as _mp
        if _mp.get_start_method() != 'fork':
            # Dataset multiprocessing is only safe under the fork start method.
            dataset_num_proc = None
        elif dataset_num_proc is None:
            import psutil
            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
            # Cap worker count by available memory (roughly 1 worker per free GB).
            memory_gb_left = psutil.virtual_memory().available / (1024**3)
            if memory_gb_left <= 2: dataset_num_proc = 1
            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')

        super().__init__(
            output_dir = output_dir,
            per_device_train_batch_size = per_device_train_batch_size,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            learning_rate = learning_rate,
            lr_scheduler_type = lr_scheduler_type,
            lr_scheduler_kwargs = lr_scheduler_kwargs,
            warmup_steps = warmup_steps,
            optim = optim,
            optim_args = optim_args,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            optim_target_modules = optim_target_modules,
            gradient_accumulation_steps = gradient_accumulation_steps,
            average_tokens_across_devices = average_tokens_across_devices,
            max_grad_norm = max_grad_norm,
            label_smoothing_factor = label_smoothing_factor,
            bf16 = bf16,
            fp16 = fp16,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            use_cache = use_cache,
            neftune_noise_alpha = neftune_noise_alpha,
            torch_empty_cache_steps = torch_empty_cache_steps,
            auto_find_batch_size = auto_find_batch_size,
            logging_strategy = logging_strategy,
            logging_steps = logging_steps,
            logging_first_step = logging_first_step,
            log_on_each_node = log_on_each_node,
            logging_nan_inf_filter = logging_nan_inf_filter,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            log_level = log_level,
            log_level_replica = log_level_replica,
            disable_tqdm = disable_tqdm,
            report_to = report_to,
            run_name = run_name,
            project = project,
            trackio_space_id = trackio_space_id,
            eval_strategy = eval_strategy,
            eval_steps = eval_steps,
            eval_delay = eval_delay,
            per_device_eval_batch_size = per_device_eval_batch_size,
            prediction_loss_only = prediction_loss_only,
            eval_on_start = eval_on_start,
            eval_do_concat_batches = eval_do_concat_batches,
            eval_use_gather_object = eval_use_gather_object,
            eval_accumulation_steps = eval_accumulation_steps,
            batch_eval_metrics = batch_eval_metrics,
            save_only_model = save_only_model,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_on_each_node = save_on_each_node,
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            reward_model_path = reward_model_path,
            judge = judge,
            max_new_tokens = max_new_tokens,
            max_length = max_length,
            temperature = temperature,
            top_p = top_p,
            top_k = top_k,
            min_p = min_p,
            repetition_penalty = repetition_penalty,
            generation_kwargs = generation_kwargs,
            use_transformers_paged = use_transformers_paged,
            cache_implementation = cache_implementation,
            missing_eos_penalty = missing_eos_penalty,
            loss_type = loss_type,
            disable_dropout = disable_dropout,
            use_vllm = use_vllm,
            vllm_model_impl = vllm_model_impl,
            vllm_guided_decoding_regex = vllm_guided_decoding_regex,
            vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
            vllm_mode = vllm_mode,
            vllm_server_base_url = vllm_server_base_url,
            vllm_server_host = vllm_server_host,
            vllm_server_port = vllm_server_port,
            vllm_server_timeout = vllm_server_timeout,
            vllm_tensor_parallel_size = vllm_tensor_parallel_size,
            ds3_gather_for_generation = ds3_gather_for_generation,
            model_init_kwargs = model_init_kwargs,
            reward_weights = reward_weights,
            dataset_num_proc = dataset_num_proc,
            gpu_memory_utilization = gpu_memory_utilization,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    "which is self.per_device_train_batch_size * gradient_accumulation_steps."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length

pass
|
| 679 |
+
|
| 680 |
+
class _UnslothXPOTrainer(OnlineDPOTrainer):
    """"""

    # Hub model-card tags and paper metadata used by the TRL model-card machinery.
    _tag_names = ["trl", "xpo"]
    _name = "XPO"
    _paper = {
        "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
        "id": "2405.21046",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @article{jung2024binary,
                title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
                author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
                year = 2024,
                eprint = {arXiv:2405.21046}
            }"""),
    }
|
| 697 |
+
|
| 698 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    ref_model: Union[PreTrainedModel, nn.Module] = None,
    reward_funcs: Optional[nn.Module] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[XPOConfig] = None,
    data_collator: Optional[Callable] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    # Deprecated parameters
    reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
) -> None:
    """Set up the XPO trainer on top of the OnlineDPO machinery.

    Delegates all construction to ``OnlineDPOTrainer.__init__`` and then
    replaces the statistics dictionary with XPO-specific entries.
    Raises ``ValueError`` if more than one reward function is supplied.
    """
    super().__init__(
        model=model,
        ref_model=ref_model,
        judge=judge,
        reward_funcs=reward_funcs,
        reward_model=reward_model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        reward_processing_classes=reward_processing_classes,
        peft_config=peft_config,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Scalar alpha, or one alpha per epoch when a list is given.
    self._alpha = self.args.alpha

    # Overwrite the OnlineDPO stats: drop "non_score_reward"/"rlhf_reward"/"scores",
    # add the XPO loss terms, and split "contain_eos_token" into model/ref variants.
    stat_keys = (
        "loss/dpo",
        "loss/xpo",
        "objective/kl",
        "objective/entropy",
        "rewards/chosen",
        "rewards/rejected",
        "rewards/accuracies",
        "rewards/margins",
        "logps/chosen",
        "logps/rejected",
        "val/model_contain_eos_token",
        "val/ref_contain_eos_token",
        "alpha",
        "beta",
    )
    self.stats = {key: [] for key in stat_keys}
    if self.reward_funcs is not None:
        if len(self.reward_funcs) != 1:
            raise ValueError("XPOTrainer only supports one reward function/model.")
        self.reward_funcs = self.reward_funcs[0]
        # Reward-model-based runs additionally track raw score statistics.
        self.stats["objective/model_scores"] = []
        self.stats["objective/ref_scores"] = []
        self.stats["objective/scores_margin"] = []
|
| 768 |
+
|
| 769 |
+
@property
def alpha(self):
    """XPO loss weight for the current epoch.

    When ``self._alpha`` is a list, the entry for the current epoch is used and
    the last entry is reused for all remaining epochs; otherwise the scalar is
    returned unchanged.
    """
    if isinstance(self._alpha, list):
        # BUGFIX: Trainer's ``state.epoch`` is a float (fractional mid-epoch);
        # indexing a list with a float raises TypeError, so floor it first.
        epoch = int(self.state.epoch)
        return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
    else:
        return self._alpha
|
| 776 |
+
|
| 777 |
+
def _generate_completions(self, prompts, model):
    """Sample completions for ``prompts`` from both the policy and the reference.

    Returns a ``(model_output, ref_output)`` pair of generated token tensors
    (prompt + completion).
    """
    # Policy generation.
    with unwrap_model_for_generation(model, self.accelerator) as policy_for_gen:
        model_output = policy_for_gen.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

    # Pick the module to use as the reference: with no explicit ref model we
    # fall back to the policy itself (its PEFT base model when adapters exist).
    if self.ref_model is None:
        unwrapped_policy = self.accelerator.unwrap_model(model)
        if is_peft_available() and isinstance(unwrapped_policy, PeftModel):
            ref_candidate = unwrapped_policy.get_base_model()
        else:
            ref_candidate = unwrapped_policy
    else:
        ref_candidate = self.accelerator.unwrap_model(self.ref_model)

    with unwrap_model_for_generation(ref_candidate, self.accelerator) as ref_for_gen:
        ref_output = ref_for_gen.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

    return model_output, ref_output
|
| 804 |
+
|
| 805 |
+
def _process_completions(self, model_output, ref_output, prompts):
    """Truncate each completion at EOS and re-attach it to its prompt.

    Returns ``(model_data, ref_data)`` dicts with ``input_ids``,
    ``attention_mask`` and the original ``raw`` prompts.
    """
    context_length = prompts["input_ids"].shape[1]
    eos_id = self.processing_class.eos_token_id
    pad_id = self.processing_class.pad_token_id

    def attach(generation):
        # Strip the prompt prefix, cut everything after the first EOS,
        # then concatenate prompt + cleaned completion.
        completion_ids = generation[:, context_length:]
        completion_ids, completion_mask = truncate_right(completion_ids, eos_id, pad_id)
        return {
            "input_ids": torch.cat((prompts["input_ids"], completion_ids), dim=1),
            "attention_mask": torch.cat((prompts["attention_mask"], completion_mask), dim=1),
            "raw": prompts["raw"],
        }

    return attach(model_output), attach(ref_output)
|
| 831 |
+
|
| 832 |
+
def _compute_rewards(self, model_data, ref_data, context_length):
    """Score both completion sets with the reward model.

    Applies ``args.missing_eos_penalty`` to sequences that never emit EOS.
    Returns ``(model_scores, ref_scores)``.
    """
    pad_id = self.processing_class.pad_token_id
    with torch.no_grad():
        _, model_scores, _ = get_reward(self.reward_funcs, model_data["input_ids"], pad_id, context_length)
        _, ref_scores, _ = get_reward(self.reward_funcs, ref_data["input_ids"], pad_id, context_length)

    penalty = self.args.missing_eos_penalty
    if penalty is not None:
        # Penalize completions that never produced an EOS token.
        eos_id = self.processing_class.eos_token_id
        model_has_eos = torch.any(model_data["input_ids"] == eos_id, dim=-1)
        ref_has_eos = torch.any(ref_data["input_ids"] == eos_id, dim=-1)
        model_scores[~model_has_eos] -= penalty
        ref_scores[~ref_has_eos] -= penalty

    return model_scores, ref_scores
|
| 849 |
+
|
| 850 |
+
def _compute_judge(self, model_data, ref_data, context_length):
    """Ask the pairwise judge whether the policy completion beats the reference.

    Returns a boolean tensor; ``True`` means the policy completion was ranked first.
    """
    prompts = model_data["raw"]
    decode = self.processing_class.batch_decode

    model_completions = [
        text.strip()
        for text in decode(model_data["input_ids"][:, context_length:], skip_special_tokens=True)
    ]
    ref_completions = [
        text.strip()
        for text in decode(ref_data["input_ids"][:, context_length:], skip_special_tokens=True)
    ]

    if is_conversational({"prompt": prompts[0]}):
        # Flatten chat-style messages into plain strings for the judge.
        environment = jinja2.Environment()
        template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
        prompts = [template.render(messages=message) for message in prompts]
        model_completions = [
            template.render(messages=[{"role": "assistant", "content": completion}])
            for completion in model_completions
        ]
        ref_completions = [
            template.render(messages=[{"role": "assistant", "content": completion}])
            for completion in ref_completions
        ]

    ranks_of_first_completion = self.judge.judge(
        prompts,
        list(zip(model_completions, ref_completions)),
    )
    # rank == 0 -> the policy (first) completion won; rank == 1 -> the reference won.
    return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
|
| 884 |
+
|
| 885 |
+
def _compute_logprobs(self, model, model_data, ref_data, context_length):
    """Compute per-token logprobs of both completion sets under both models.

    Padding positions are zeroed so later ``.sum(1)`` reductions ignore them.
    Returns the four tensors in the order
    (model|model_data, model|ref_data, ref|ref_data, ref|model_data).
    """
    def per_token_logps(net, data):
        # Shift by one: logits at position t score the token at t+1.
        logits = net(data["input_ids"], attention_mask=data["attention_mask"]).logits[:, context_length - 1 : -1]
        return selective_log_softmax(logits, data["input_ids"][:, context_length:])

    # Policy logprobs on its own completions, and on the reference completions
    # (the latter feeds the XPO exploration term).
    model_logprobs_model_data = per_token_logps(model, model_data)
    model_logprobs_ref_data = per_token_logps(model, ref_data)

    # Reference logprobs never need gradients.
    with torch.no_grad():
        if self.ref_model is None:
            # PEFT case: the base model with adapters disabled is the reference.
            with model.disable_adapter():
                ref_logprobs_model_data = per_token_logps(model, model_data)
                ref_logprobs_ref_data = per_token_logps(model, ref_data)
        else:
            ref_logprobs_model_data = per_token_logps(self.ref_model, model_data)
            ref_logprobs_ref_data = per_token_logps(self.ref_model, ref_data)

    # Zero out padding so summed sequence logprobs exclude pad tokens.
    model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
    ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
    model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
    model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
    ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
    ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)

    return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
|
| 916 |
+
|
| 917 |
+
def _compute_losses(
    self,
    model_logprobs_model_data,
    model_logprobs_ref_data,
    ref_logprobs_ref_data,
    ref_logprobs_model_data,
    chosen_mask,
):
    """Combine the DPO preference loss with the XPO exploration term.

    Returns ``(loss, dpo_losses, xpo_losses)`` where ``loss`` is the scalar
    mean of the two per-example loss vectors.
    """
    # Sequence-level logprob sums for each (scoring model, data) pairing.
    model_on_model = model_logprobs_model_data.sum(1)
    model_on_ref = model_logprobs_ref_data.sum(1)
    ref_on_ref = ref_logprobs_ref_data.sum(1)
    ref_on_model = ref_logprobs_model_data.sum(1)

    # Select chosen/rejected logprobs according to the preference mask.
    chosen_log_ratios = (
        torch.where(chosen_mask, model_on_model, model_on_ref)
        - torch.where(chosen_mask, ref_on_model, ref_on_ref)
    )
    rejected_log_ratios = (
        torch.where(~chosen_mask, model_on_model, model_on_ref)
        - torch.where(~chosen_mask, ref_on_model, ref_on_ref)
    )

    # DPO logits: difference of chosen and rejected log ratios.
    logits = chosen_log_ratios - rejected_log_ratios

    if self.args.loss_type == "sigmoid":
        dpo_losses = -F.logsigmoid(self.beta * logits)
    elif self.args.loss_type == "ipo":
        dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
    else:
        raise NotImplementedError(f"invalid loss type {self.args.loss_type}")

    # XPO term: alpha-weighted policy logprob of the reference completions.
    xpo_losses = self.alpha * model_on_ref

    loss = (dpo_losses + xpo_losses).mean()
    return loss, dpo_losses, xpo_losses
|
| 956 |
+
|
| 957 |
+
def _log_statistics(
    self,
    model_data,
    ref_data,
    model_logprobs_model_data,
    model_logprobs_ref_data,
    ref_logprobs_ref_data,
    ref_logprobs_model_data,
    chosen_mask,
    dpo_losses,
    xpo_losses,
    context_length,
    model_scores=None,
    ref_scores=None,
):
    """Append this step's training statistics to ``self.stats``."""
    def gather_mean(tensor):
        # Gather across processes first so the logged mean covers the global batch.
        return self.accelerator.gather_for_metrics(tensor).mean().item()

    # Loss components.
    self.stats["loss/dpo"].append(gather_mean(dpo_losses))
    self.stats["loss/xpo"].append(gather_mean(xpo_losses))

    # Raw reward-model scores (only when a reward model is in use).
    if self.reward_funcs is not None:
        self.stats["objective/model_scores"].append(gather_mean(model_scores))
        self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
        self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))

    # Sequence-level logprob sums for each (scoring model, data) pairing.
    model_on_model = model_logprobs_model_data.sum(1)
    model_on_ref = model_logprobs_ref_data.sum(1)
    ref_on_ref = ref_logprobs_ref_data.sum(1)
    ref_on_model = ref_logprobs_model_data.sum(1)

    chosen_model_logprobs = torch.where(chosen_mask, model_on_model, model_on_ref)
    chosen_ref_logprobs = torch.where(chosen_mask, ref_on_model, ref_on_ref)
    chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs

    rejected_model_logprobs = torch.where(~chosen_mask, model_on_model, model_on_ref)
    rejected_ref_logprobs = torch.where(~chosen_mask, ref_on_model, ref_on_ref)
    rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs

    self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
    self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))

    # Implicit rewards are the beta-scaled log ratios.
    chosen_rewards = chosen_log_ratios * self.beta
    rejected_rewards = rejected_log_ratios * self.beta
    self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
    self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))

    # KL(policy || reference), averaged over the two completion sets.
    kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
    kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
    mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
    self.stats["objective/kl"].append(gather_mean(mean_kl))

    # Policy entropy estimate (negative sequence logprob), averaged likewise.
    mean_entropy = ((-model_on_model).mean() + (-model_on_ref).mean()) / 2
    self.stats["objective/entropy"].append(gather_mean(mean_entropy))

    # Preference margin and accuracy.
    margin = chosen_rewards - rejected_rewards
    self.stats["rewards/margins"].append(gather_mean(margin.mean()))
    accuracy = (margin > 0).float()
    self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))

    # Fraction of completions that actually emitted an EOS token.
    eos_id = self.processing_class.eos_token_id
    model_eos = (model_data["input_ids"][:, context_length:] == eos_id).any(dim=1)
    ref_eos = (ref_data["input_ids"][:, context_length:] == eos_id).any(dim=1)
    self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
    self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))

    # Current loss-weight hyperparameters.
    self.stats["alpha"].append(self.alpha)
    self.stats["beta"].append(self.beta)
|
| 1039 |
+
|
| 1040 |
+
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one XPO online training step.

    Tokenizes the incoming prompt batch, samples completions from both the
    policy ``model`` and the reference model, picks the chosen/rejected sides
    (via reward functions or a judge), computes log-probabilities and the
    combined DPO+XPO loss, logs statistics, and backpropagates.

    Args:
        model: The policy model being optimized; switched to train mode here.
        inputs: Column-oriented batch dict; must contain a ``"prompt"`` column.
        num_items_in_batch: Accepted for Trainer API compatibility; not used
            in this implementation.

    Returns:
        The detached loss divided by ``gradient_accumulation_steps``.
    """
    model.train()

    # Apply chat template and tokenize the input.
    # The batch arrives column-oriented; convert to a list of row dicts first.
    batch_size = len(next(iter(inputs.values())))
    prompts = inputs["prompt"]
    inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
    inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
    inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
    inputs = self.data_collator(inputs)

    # Only the prompt_* tensors are needed from here on; the rest of the
    # collated batch is dropped to free memory.
    inputs = self._prepare_inputs(inputs)
    context_length = inputs["prompt_input_ids"].shape[1]
    prompts = {
        "input_ids": inputs["prompt_input_ids"],
        "attention_mask": inputs["prompt_attention_mask"],
        "raw": prompts,
    }
    del inputs

    # Sample completions from both the model and the reference model
    model_output, ref_output = self._generate_completions(prompts, model)

    # Process model completions
    model_data, ref_data = self._process_completions(model_output, ref_output, prompts)

    # Compute rewards. With reward functions, "chosen" is whichever side
    # scores at least as high; otherwise a pairwise judge decides.
    if self.reward_funcs is not None:
        model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
        chosen_mask = model_scores >= ref_scores
    else:
        model_scores, ref_scores = None, None
        chosen_mask = self._compute_judge(model_data, ref_data, context_length)

    # Compute logprobs of both completion sets under both models.
    model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
        self._compute_logprobs(model, model_data, ref_data, context_length)
    )

    # Compute loss (combined objective plus its DPO and XPO components).
    loss, dpo_losses, xpo_losses = self._compute_losses(
        model_logprobs_model_data,
        model_logprobs_ref_data,
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
    )

    # Log everything (detached copies so logging cannot hold the graph alive).
    self._log_statistics(
        model_data,
        ref_data,
        model_logprobs_model_data.detach(),
        model_logprobs_ref_data.detach(),
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
        dpo_losses.detach(),
        xpo_losses.detach(),
        context_length,
        model_scores,
        ref_scores,
    )

    # Periodically release cached device memory if configured.
    if (
        self.args.torch_empty_cache_steps is not None
        and self.state.global_step % self.args.torch_empty_cache_steps == 0
    ):
        empty_cache()

    kwargs = {}
    # For LOMO optimizers you need to explicitly use the learning rate
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    self.accelerator.backward(loss, **kwargs)

    return loss.detach() / self.args.gradient_accumulation_steps
class UnslothXPOTrainer(_UnslothXPOTrainer):
    """
    Trainer for Exploratory Preference Optimization (XPO).

    It is implemented as a subclass of [`OnlineDPOTrainer`].

    Args:
        model ([`~transformers.PreTrainedModel`]):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model ([`PreTrainedModelWrapper`]):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
            and loss. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized.
        reward_funcs ([`~transformers.PreTrainedModel`]):
            The reward model to score completions with, preferably an
            [`~transformers.AutoModelForSequenceClassification`].
        judge ([`BasePairwiseJudge`]):
            The judge to use for pairwise comparison of model completions.
        args ([`XPOConfig`]):
            The XPO config arguments to use for training.
        data_collator ([`~transformers.DataCollator`]):
            The data collator to use for training. If None is specified, the default data collator
            ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset ([`~datasets.Dataset`]):
            The dataset to use for training.
        eval_dataset ([`~datasets.Dataset`]):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.

    reward_model:

        <Deprecated version="0.22.0">

        This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.

        </Deprecated>

    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        reward_funcs = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        reward_processing_classes = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        reward_model = None,
        **kwargs
    ):
        # Auto-generated Unsloth wrapper __init__: normalizes precision flags,
        # eval settings, sequence lengths and the data collator before
        # delegating to the underlying XPO trainer.
        if args is None: args = UnslothXPOConfig()
        # bf16/fp16 may be non-bool sentinels on some configs; coerce to bool.
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Resolve the model's compute dtype (falls back to the embedding dtype).
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Refuse mismatched model-dtype / requested-precision combinations.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        # If an eval dataset was given but evaluation is disabled, enable it.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Older transformers have a known gradient-accumulation bug; warn.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Only shrink the eval batch size when it is still the default (8).
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        # Keep full-eval precision consistent with the training precision.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metric callbacks need logits returned from the forward pass.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            # Ensure model.warnings_issued is a dict (transformers expects one).
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        # Reconcile args.max_seq_length with the model's supported maximum.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        # Right padding is required for training with these collators.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        # Swap the collator if its type does not match whether the dataset
        # already carries precomputed 'labels'.
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            # Vision collator prepares the dataset itself.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processor wrappers without a .pad method: collate with the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('xpo_trainer', other_metrics)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_funcs = reward_funcs,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            reward_processing_classes = reward_processing_classes,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            reward_model = reward_model,**kwargs)
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        # Replace the base trainer's NEFTune hook with a direct attribute so
        # the noise alpha survives Unsloth's patched forward.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        # Propagate the accelerator's grad scaler down the wrapped model chain.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        # Wrap .train so the model is put into training mode before each run.
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        # Copy the chat template onto the vLLM tokenizer if it lacks one.
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass

pass