Upload folder using huggingface_hub

- H100_JUPYTER_SETUP.md +87 -17
- README.md +18 -1
- _dashboard_state.json +108 -88
- colab_train_unsloth.ipynb +128 -0
- inference.ipynb +27 -28
- pyproject.toml +7 -4
- requirements-unsloth.txt +13 -0
- run_agent_unsloth.py +1 -1
- scripts/register_kernel_312.sh +17 -0
- train.ipynb +32 -19
- train_unsloth.ipynb +149 -0
- training_script.py +228 -35
- training_unsloth.py +67 -29
- unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
- unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
- unsloth_compiled_cache/UnslothCPOTrainer.py +1936 -0
- unsloth_compiled_cache/UnslothGKDTrainer.py +1287 -0
- unsloth_compiled_cache/UnslothKTOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothORPOTrainer.py +1860 -0
- unsloth_compiled_cache/UnslothPPOTrainer.py +1634 -0
- unsloth_compiled_cache/UnslothPRMTrainer.py +1109 -0
- unsloth_compiled_cache/UnslothRewardTrainer.py +1327 -0
- unsloth_compiled_cache/UnslothSFTTrainer.py +1588 -0
- unsloth_compiled_cache/moe_utils.py +1320 -0
- uv.lock +0 -0
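Each entry in the list above follows the diffstat shape `- path +added -removed`. As an aside, that shape is regular enough to parse mechanically; the function below is a throwaway illustration (the name `parse_diffstat` is ours, not part of this repo):

```python
import re

def parse_diffstat(line: str):
    """Parse a '- path +A -R' diffstat entry into (path, added, removed), or None."""
    m = re.match(r"-\s+(\S+)\s+\+(\d+)\s+-(\d+)$", line)
    if not m:
        return None
    return m.group(1), int(m.group(2)), int(m.group(3))
```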
H100_JUPYTER_SETUP.md
CHANGED

@@ -4,7 +4,7 @@ This guide walks you through setting up the OpenEnv Bio Experiment environment o
 
 ## Prerequisites
 
-- **Python**
+- **Python** 3.10, 3.11, or **3.12** (3.12 recommended for H100; 3.13 is not supported—numba, vllm, and others require <3.13)
 - **uv** – fast Python package manager ([install instructions](#installing-uv))
 - **NVIDIA driver** ≥ 535.104.05 (usually pre-installed on H100 instances)
 - **CUDA** – H100 uses CUDA 12.x; PyTorch wheels bundle the runtime, so a separate CUDA Toolkit is not required
@@ -38,14 +38,16 @@ cd OpenENV-Hackathon
 
 ### 2. Use uv's auto PyTorch backend
 
+The project uses Python 3.12 (see `.python-version`). uv will create a 3.12 venv. For H100 (CUDA 12.x):
 
 ```bash
-# Install everything: core + training (TRL, transformers, torch
+# Install everything: core + training (TRL, transformers, torch) + Jupyter
 UV_TORCH_BACKEND=cu128 uv sync --extra train
 
-# Add
+# Add Unsloth for training_unsloth.py (skips trl downgrade; Unsloth works with TRL 0.29)
-uv
+uv pip install unsloth unsloth_zoo --no-deps
+
+# (ipykernel is included in --extra train)
 ```
 
 If `UV_TORCH_BACKEND=cu128` fails (e.g., cu128 wheels not available yet), try:
@@ -57,12 +59,20 @@ UV_TORCH_BACKEND=cu126 uv sync --extra train
 ### 3. Register the environment as a Jupyter kernel
 
 ```bash
-uv run python -m ipykernel install --user --name openenv-bio --display-name "OpenEnv Bio (
+uv run python -m ipykernel install --user --name openenv-bio-312 --display-name "OpenEnv Bio (Python 3.12)"
+```
+
+Or run the helper script (from project root):
+
+```bash
+bash scripts/register_kernel_312.sh
 ```
 
+Then select **"OpenEnv Bio (Python 3.12)"** in the notebook kernel picker.
+
 ### 4. Verify CUDA
 
-In a new Jupyter notebook, select the **"OpenEnv Bio (
+In a new Jupyter notebook, select the **"OpenEnv Bio (Python 3.12)"** kernel and run:
 
 ```python
 import torch
@@ -108,7 +118,6 @@ Then run:
 
 ```bash
 uv sync --extra train
-uv add ipykernel jupyter --extra train
 ```
 
 For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
@@ -128,20 +137,20 @@ For CUDA 12.6 instead of 12.8, use `cu126` in the index URL and source names.
 On H100, use the quantized Unsloth entrypoints:
 
 ```bash
-uv run python training_unsloth.py --dry-run
+uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dry-run
-uv run python training_unsloth.py --model-id Qwen/Qwen3
+uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b
 uv run python run_agent_unsloth.py
 ```
 
-The checked-in `inference.ipynb` notebook
+The checked-in `inference.ipynb` notebook uses `training_unsloth.py` helpers with 4-bit loading. vLLM fast inference is disabled to avoid dependency conflicts.
 
 ## Running Training in a Jupyter Notebook
 
 Example cell:
 
 ```python
-# In a notebook with the OpenEnv Bio (
+# In a notebook with the OpenEnv Bio (Python 3.12) kernel
-!uv run python training_unsloth.py --model-id Qwen/Qwen3
+!uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dry-run
 ```
 
 Or run interactively from Python:
@@ -150,8 +159,8 @@ Or run interactively from Python:
 import subprocess
 subprocess.run([
     "uv", "run", "python", "training_unsloth.py",
-    "--model-id", "Qwen/Qwen3
+    "--model-id", "Qwen/Qwen3-4B-Base",
-    "--output-dir", "training/grpo-unsloth-
+    "--output-dir", "training/grpo-unsloth-qwen3-4b",
 ], check=True)
 ```
@@ -159,16 +168,32 @@ subprocess.run([
 
 | Component      | Version / Notes                                      |
 |----------------|------------------------------------------------------|
-| Python         |
+| Python         | 3.10–3.12 (3.12 recommended; 3.13 not supported)     |
 | uv             | ≥ 0.5.3 (for PyTorch index support)                  |
 | torch          | ≥ 2.10.0 (cu128 or cu126 for H100)                   |
-| transformers   | ≥
+| transformers   | ≥ 4.57 (with unsloth ≥ 2025.10.14)                   |
 | trl            | ≥ 0.29.0                                             |
 | accelerate     | ≥ 1.13.0                                             |
 | Jupyter        | Optional, for notebook workflows                     |
 
 ## Troubleshooting
 
+### `RuntimeError: Cannot install on Python version 3.13.x` or numba / setup.py errors
+
+Python 3.13 is not supported (numba, vllm, and other deps require <3.13). Use Python 3.12:
+
+```bash
+# With uv: ensure Python 3.12 is available, then sync
+uv python install 3.12
+uv sync --extra train
+
+# Or create venv explicitly with 3.12
+uv venv --python 3.12
+UV_TORCH_BACKEND=cu128 uv sync --extra train
+```
+
+The project's `.python-version` file pins 3.12; uv will use it when creating the venv.
+
 ### `torch.cuda.is_available()` is False
 
 - Confirm the Jupyter kernel is the one where you ran `uv sync` (the one with `ipykernel`).
@@ -193,6 +218,51 @@ export HF_HUB_DISABLE_SYMLINKS_WARNING=1
 - Use a smaller model (e.g., `Qwen/Qwen3.5-0.8B`) for experiments.
 - Keep `--disable-4bit` off unless you explicitly need wider weights.
 
+### `ModuleNotFoundError: No module named 'vllm.lora.models'`
+
+Unsloth's `unsloth_zoo` imports vLLM at load time and expects `vllm.lora.models`, which some vLLM versions don't have. Fix by installing a compatible vLLM:
+
+```bash
+pip install "vllm==0.8.2"
+# or
+pip install "vllm==0.7.3"
+```
+
+**Note:** vLLM 0.8.2 pins `torch==2.6.0`, which conflicts with this project's `torch>=2.10.0`. If you hit that conflict:
+
+1. Use a **separate environment** with torch 2.6–2.8 + vllm 0.8.2 + unsloth.
+2. Or use the non-Unsloth path (`training_script.py` / `train.ipynb`) which doesn't depend on vLLM.
+
+### `KeyError: 'qwen3_5'` / Qwen3.5 not supported
+
+Qwen3.5 requires transformers 5.x. With transformers 4.57, use **Qwen2.5** instead:
+- `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`
+- `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`
+- `Qwen/Qwen2.5-3B-Instruct`
+
+### `NameError: name 'PreTrainedConfig' is not defined` / `check_model_inputs` ImportError
+
+Use unsloth ≥ 2025.10.14 (PreTrainedConfig fix) with transformers ≥ 4.57 (check_model_inputs). Run `uv sync --extra train` to get compatible versions.
+
+### `ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'`
+
+unsloth_zoo expects TRL <0.20. The project pins `trl>=0.19.0,<0.20`. If you see this error, ensure you've run `uv sync --extra train` so the locked trl version is used. Alternatively, try:
+
+```bash
+pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo
+```
+
+(A newer unsloth_zoo may fix this and allow TRL 0.20+.)
+
+### Unsloth import order warning
+
+If you see "Unsloth should be imported before trl, transformers, peft", ensure `training_unsloth` is imported before `training_script` in your notebook:
+
+```python
+from training_unsloth import make_training_args, run_training  # first
+import training_script as base
+```
+
 ## See Also
 
 - Main [README.md](README.md) for project overview, APIs, and usage
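The supported-range rule stated in the guide above (3.10 through 3.12 work, 3.13 does not) can be expressed as a small preflight check run before `uv sync`. This is a sketch; the helper name `python_supported` is ours and not part of the repo:

```python
import sys

def python_supported(version_info=sys.version_info) -> bool:
    """True if the interpreter falls in the project's supported range: 3.10 <= version < 3.13."""
    major, minor = version_info[0], version_info[1]
    return (3, 10) <= (major, minor) < (3, 13)
```

Calling `python_supported()` early lets a setup script fail with a clear message instead of surfacing a numba or vllm resolution error deep inside the sync.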
README.md
CHANGED

@@ -384,6 +384,23 @@ uv run python training_unsloth.py --dry-run
 uv run python training_unsloth.py --model-id Qwen/Qwen3.5-4B
 ```
 
+**Laptop / mid-range GPU (e.g. 12GB VRAM):** Use reduced batch size and sequence length to avoid OOM:
+
+```bash
+uv sync --extra train
+uv pip install unsloth unsloth_zoo --no-deps  # if using training_unsloth.py
+uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b --dataset-episodes 12 --rollout-steps 6 --per-device-train-batch-size 1 --num-generations 2 --gradient-accumulation-steps 4 --max-seq-length 1024 --trust-remote-code
+```
+
+If you still hit OOM, try `--max-seq-length 768` or `--num-generations 1`.
+
+**PyTorch CUDA:** Use the PyTorch index that matches your GPU. For older cards (RTX 20/30/40 series): `uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`. For **RTX 50 series (Blackwell, sm_120)** you need a CUDA 12.8 build:
+
+```bash
+uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
+uv pip install triton-windows  # required by Unsloth on Windows
+```
+
 Key arguments:
 
 | Argument | Default | Description |
@@ -405,7 +422,7 @@ Key arguments:
 
 By default the reward function reconstructs prompt states locally so the prompt and reward stay aligned. Switch to a live server-backed reward loop with `--reward-backend remote --base-url http://localhost:8000`.
 
-`training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`,
+`training_unsloth.py` adds H100-oriented options such as `--max-seq-length`, `--disable-4bit`, and LoRA settings (`--lora-r`, `--lora-alpha`, `--lora-dropout`). vLLM fast inference is disabled to avoid dependency conflicts.
 
 After training, the script saves plots to the output directory:
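The OOM-avoidance flags added above trade per-step memory for gradient accumulation: the effective optimisation batch is the product of the two settings. A minimal sketch of that arithmetic (the helper name is ours, for illustration):

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of samples contributing to one optimiser step."""
    return per_device_batch * grad_accum_steps * num_devices

# The laptop settings above (--per-device-train-batch-size 1, --gradient-accumulation-steps 4)
# keep an effective batch of 4 while holding only one sample's activations in VRAM at a time.
```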
_dashboard_state.json
CHANGED

@@ -1,29 +1,29 @@
 {
-  "timestamp":
+  "timestamp": 1772986999.4539518,
   "step": 20,
   "episode_done": false,
-  "cumulative_reward": -1.
+  "cumulative_reward": -1.2923024829860532,
-  "gen_time_s":
+  "gen_time_s": 17.14,
-  "model_response_raw": "{\n\"action_type\": \"
+  "model_response_raw": "{\n\"action_type\": \"design_followup_experiment\",\n\"method\": \"\",\n\"parameters\": {\n\"candidate_target_genes\": [],\n\"suggested_wet_lab_targets\": []\n},\n\"Justification\": \"The previous attempt failed due to incomplete genomic profiling required for robust marker identification via CRISPR/Cas9 editing followed by functional readout. Without identifying specific high-confidence target genes within the identified clusters through comprehensive transcriptomic characterization first, designing a focused knockdown screen lacks the necessary molecular basis.\",\n\"Confidence\": 0.85\n}",
   "model_thinking": "",
   "thinking_enabled": false,
   "task": {
-    "problem_statement": "Validate candidate biomarker
+    "problem_statement": "Validate candidate biomarker ERBB2 for breast cancer in breast tissue using single-cell RNA sequencing.",
     "organism": "human",
-    "tissue": "
+    "tissue": "breast",
     "modality": "scRNA-seq",
     "conditions": [
       "healthy",
-      "
+      "breast_cancer"
     ],
-    "budget_limit":
+    "budget_limit": 89988.7088221059,
-    "time_limit_days":
+    "time_limit_days": 160.54122628462622
   },
   "resources": {
     "budget_used": 28900.0,
-    "budget_remaining":
+    "budget_remaining": 61088.71,
     "time_used_days": 18.0,
-    "time_remaining_days":
+    "time_remaining_days": 142.5,
     "samples_consumed": 0,
     "compute_hours_used": 12.0
   },
@@ -32,9 +32,9 @@
   "step_index": 1,
   "action_type": "collect_sample",
   "method": null,
-  "output_summary": "Collected 6 samples (quality=0.
+  "output_summary": "Collected 6 samples (quality=0.42)",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.422,
   "resource_cost": 5000.0,
   "time_cost_days": 7.0
 },
@@ -62,9 +62,9 @@
   "step_index": 4,
   "action_type": "prepare_library",
   "method": null,
-  "output_summary": "Library prepared (complexity=0.
+  "output_summary": "Library prepared (complexity=0.59)",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.595,
   "resource_cost": 8000.0,
   "time_cost_days": 3.0
 },
@@ -72,42 +72,42 @@
   "step_index": 5,
   "action_type": "sequence_cells",
   "method": null,
-  "output_summary": "Sequenced
+  "output_summary": "Sequenced 8578 cells, 9244 genes detected, median UMI=3090",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.359,
   "resource_cost": 15000.0,
   "time_cost_days": 5.0
 },
 {
   "step_index": 6,
-  "action_type": "filter_data",
-  "method": null,
-  "output_summary": "Action blocked: Cannot filter data before QC",
-  "success": false,
-  "quality_score": 1.0,
-  "resource_cost": 50.0,
-  "time_cost_days": 0.2
-},
-{
-  "step_index": 7,
   "action_type": "run_qc",
   "method": null,
   "output_summary": "QC metrics computed",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.821,
   "resource_cost": 100.0,
   "time_cost_days": 0.5
 },
 {
-  "step_index":
+  "step_index": 7,
   "action_type": "filter_data",
   "method": null,
-  "output_summary": "Filtered
+  "output_summary": "Filtered 8578 \u2192 7050 cells (82% retained)",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.822,
   "resource_cost": 50.0,
   "time_cost_days": 0.2
 },
+{
+  "step_index": 8,
+  "action_type": "cluster_cells",
+  "method": null,
+  "output_summary": "Action blocked: Cannot cluster before normalisation",
+  "success": false,
+  "quality_score": 1.0,
+  "resource_cost": 100.0,
+  "time_cost_days": 0.5
+},
 {
   "step_index": 9,
   "action_type": "normalize_data",
@@ -142,14 +142,24 @@
   "step_index": 12,
   "action_type": "cluster_cells",
   "method": null,
-  "output_summary": "Found
+  "output_summary": "Found 7 clusters",
   "success": true,
-  "quality_score": 0.
+  "quality_score": 0.741,
   "resource_cost": 100.0,
   "time_cost_days": 0.5
 },
 {
   "step_index": 13,
+  "action_type": "marker_selection",
+  "method": null,
+  "output_summary": "Action blocked: Cannot select markers without DE results",
+  "success": false,
+  "quality_score": 1.0,
+  "resource_cost": 100.0,
+  "time_cost_days": 0.5
+},
+{
+  "step_index": 14,
   "action_type": "design_followup_experiment",
   "method": null,
   "output_summary": "Follow-up experiment design proposed (evidence_signals=1)",
@@ -159,7 +169,7 @@
   "time_cost_days": 0.5
 },
 {
-  "step_index":
+  "step_index": 15,
   "action_type": "synthesize_conclusion",
   "method": null,
   "output_summary": "Action blocked: Cannot synthesise conclusion before discovering or validating markers; Cannot synthesise conclusion befo",
@@ -168,16 +178,6 @@
   "resource_cost": 0.0,
   "time_cost_days": 0.5
 },
-{
-  "step_index": 15,
-  "action_type": "marker_selection",
-  "method": null,
-  "output_summary": "Action blocked: Cannot select markers without DE results",
-  "success": false,
-  "quality_score": 1.0,
-  "resource_cost": 100.0,
-  "time_cost_days": 0.5
-},
 {
   "step_index": 16,
   "action_type": "synthesize_conclusion",
@@ -222,8 +222,8 @@
   "Cannot synthesise conclusion before inferring pathways or mechanisms"
 ],
 "uncertainty_summary": {
-  "avg_uncertainty": 0.
+  "avg_uncertainty": 0.177,
-  "avg_quality": 0.
+  "avg_quality": 0.804
 },
 "reward_breakdown": {
   "validity": -1.0,
@@ -249,72 +249,92 @@
 "latent": {
   "cell_populations": [
     {
+      "name": "luminal_epithelial",
+      "proportion": 0.433,
       "marker_genes": [
+        "KRT8",
+        "KRT18",
+        "EPCAM"
       ],
+      "state": "normal"
     },
     {
+      "name": "basal_epithelial",
+      "proportion": 0.157,
       "marker_genes": [
+        "KRT14",
+        "KRT5",
+        "TP63"
      ],
       "state": "normal"
     },
     {
+      "name": "fibroblast",
+      "proportion": 0.119,
+      "marker_genes": [
+        "COL1A1",
+        "COL3A1",
+        "FAP"
+      ],
+      "state": "quiescent"
+    },
+    {
+      "name": "T_cell",
+      "proportion": 0.105,
+      "marker_genes": [
+        "CD3D",
+        "CD3E",
+        "CD8A"
+      ],
+      "state": "activated"
+    },
+    {
+      "name": "macrophage",
+      "proportion": 0.096,
       "marker_genes": [
+        "CD68",
+        "CD163",
+        "CSF1R"
       ],
+      "state": "inflammatory"
     },
     {
       "name": "endothelial",
+      "proportion": 0.09,
       "marker_genes": [
+        "PECAM1",
+        "VWF",
+        "CDH5"
       ],
       "state": "quiescent"
     }
   ],
   "true_markers": [
+    "ERBB2",
+    "MKI67",
+    "CD274",
+    "VIM"
   ],
   "causal_mechanisms": [
-    "reactive astrogliosis amplifying neuroinflammation"
+    "ERBB2-driven proliferative signalling",
+    "immune evasion via PD-L1 upregulation"
   ],
   "true_pathways": {
+    "cell_cycle": 0.889,
+    "PI3K_AKT_signalling": 0.803,
+    "EMT": 0.757,
+    "immune_checkpoint": 0.579,
+    "estrogen_signalling": 0.644
   },
-  "true_de_genes_count":
+  "true_de_genes_count": 9,
-  "true_regulatory_network_size":
+  "true_regulatory_network_size": 9,
   "confounders": {},
-  "n_true_cells":
+  "n_true_cells": 14493,
   "technical": {
-    "ambient_rna_fraction": 0.
+    "ambient_rna_fraction": 0.05700028722692205,
-    "doublet_rate": 0.
+    "doublet_rate": 0.0716382392677839,
-    "dropout_rate": 0.
+    "dropout_rate": 0.1822398381996976,
-    "sample_quality": 0.
+    "sample_quality": 0.7058144963381642,
     "library_complexity": 0.8,
     "capture_efficiency": 0.6
   },
@@ -339,9 +359,9 @@
   "followup_designed": true,
   "subagent_review_requested": false,
   "conclusion_reached": false,
-  "n_cells_sequenced":
+  "n_cells_sequenced": 8578,
-  "n_cells_after_filter":
+  "n_cells_after_filter": 7050,
-  "n_clusters_found": "
+  "n_clusters_found": "7",
   "n_de_genes_found": null,
   "n_markers_found": null
 },
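The resource fields in the updated `_dashboard_state.json` appear internally consistent: `budget_remaining` matches the limit minus spend rounded to two decimals, and `time_remaining_days` the same at one decimal. A quick sanity check with the values from the state above (field names taken from the JSON; the rounding convention is our inference from the displayed values):

```python
task = {"budget_limit": 89988.7088221059, "time_limit_days": 160.54122628462622}
resources = {"budget_used": 28900.0, "time_used_days": 18.0}

# Derived fields as they appear in the dashboard state
budget_remaining = round(task["budget_limit"] - resources["budget_used"], 2)
time_remaining_days = round(task["time_limit_days"] - resources["time_used_days"], 1)
```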
colab_train_unsloth.ipynb
ADDED (@@ -0,0 +1,128 @@)

New notebook; cell contents shown in jupytext percent format:

```python
# %% [markdown]
# # Unsloth training on Colab
#
# Minimal setup: clone repo → install deps → run GRPO training with Unsloth (Qwen3-4B, 4-bit + LoRA).
#
# **Runtime**: Enable a GPU (e.g. T4) in Colab: Runtime → Change runtime type → GPU.

# %%
# 1. Clone repo (set branch/tag if needed)
REPO_URL = "https://github.com/mhtruong1031/OpenENV-Hackathon.git"  # or your fork
REPO_DIR = "OpenENV-Hackathon"

!git clone --depth 1 {REPO_URL} {REPO_DIR}
%cd {REPO_DIR}

# %%
# 2. Install requirements: project + train extras + Unsloth (no-deps to keep trl>=0.29)
!pip install -q -e ".[train]"
!pip install -q unsloth unsloth_zoo --no-deps

# Optional: reward backends
!pip install -q sentence-transformers gseapy 2>/dev/null || true

# %%
# 3. Unsloth must be imported before trl/transformers/peft
import unsloth  # noqa: F401
import torch
from pathlib import Path

from training_unsloth import make_training_args, run_training

print("CUDA:", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else "")
Path("artifacts").mkdir(exist_ok=True)

# %%
# 4. Training config (small run for Colab T4)
args = make_training_args(
    model_id="Qwen/Qwen3-4B-Base",
    output_dir="artifacts/grpo-unsloth-qwen3-4b",
    dataset_episodes=16,
    rollout_steps=10,
    collection_policy="heuristic",
    reward_backend="local",
    domain_randomise=True,
    num_generations=4,
    max_completion_length=160,
    max_prompt_length=1280,
    max_seq_length=2048,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=1.0,
    logging_steps=1,
    save_steps=25,
    trust_remote_code=True,
    dry_run=False,
    seed=42,
)
args

# %%
# 5. Run training
result = run_training(args)
print("Plots:", result["plot_paths"])

# %%
# 6. (Optional) Show loss curves
from IPython.display import Image, display
for name, path in result["plot_paths"].items():
    display(Image(filename=path))
```
inference.ipynb
CHANGED

The notebook ("# Self-Driving Lab Inference on H100 With Unsloth": loads a quantized Unsloth model, builds the training-time observation prompt, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`) keeps its cell sources except for one change: `fast_inference=True` is dropped from the quantized model load.

```diff
@@ -54,21 +57,20 @@
     trust_remote_code=True,
     max_seq_length=2048,
     load_in_4bit=True,
-    fast_inference=True,
     prepare_for_inference=True,
 )
```

The remaining hunks only normalize notebook structure: each cell's `id`, `execution_count`, and `outputs` keys are moved into canonical nbformat key order, and a trailing newline is added after the closing `}`.
pyproject.toml
CHANGED

```diff
@@ -12,7 +12,7 @@
 name = "openenv-bio-experiment"
 version = "0.1.0"
 description = "RL environment for biological experiment pipeline planning"
-requires-python = ">=3.10"
+requires-python = ">=3.10,<3.13"
 dependencies = [
     "openenv-core[core]>=0.2.0",
     "numpy>=1.24.0",
@@ -34,12 +34,15 @@ train = [
     "accelerate>=1.13.0",
     "bitsandbytes>=0.45.0",
     "datasets>=4.6.1",
+    "ipykernel>=6.0.0",
     "matplotlib>=3.10.8",
     "peft>=0.15.0",
     "torch>=2.10.0",
-    "…",
-    "…",
-    "…",
+    "torchvision>=0.20.0",  # required by transformers for Qwen3.5 (image_utils)
+    "transformers>=5.3.0",  # 5.3+ required for Qwen3.5 (qwen3_5 model type)
+    "llm-blender>=0.0.2",   # required by trl GRPOTrainer judges
+    "mergekit>=0.1.0",      # required by trl GRPOTrainer/callbacks
+    "trl>=0.29.0",          # GRPOTrainer; 0.29+ compatible with transformers 5.3
 ]

 [project.scripts]
```

(The three removed dependency lines are truncated in the extraction and left elided.)
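The tightened `requires-python = ">=3.10,<3.13"` window can also be checked at runtime before heavier imports kick in; a minimal sketch (the helper name is ours, not part of the project):

```python
import sys

def python_supported(version_info=sys.version_info) -> bool:
    # Mirrors the pyproject pin ">=3.10,<3.13" (illustrative helper).
    return (3, 10) <= tuple(version_info[:2]) < (3, 13)

print(python_supported((3, 12, 0)))  # → True
```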
requirements-unsloth.txt
ADDED (@@ -0,0 +1,13 @@)

```text
# Unsloth path: install after uv sync --extra train
# Unsloth pins trl<=0.24; this project uses trl>=0.29 for training_script.py.
# Use --no-deps to avoid downgrading trl (Unsloth works with TRL 0.29 in practice).
#
# Install:
#   uv sync --extra train
#   uv pip install unsloth unsloth_zoo --no-deps
#
# Run training with Qwen3-4B-Base:
#   uv run python training_unsloth.py --model-id Qwen/Qwen3-4B-Base --output-dir training/grpo-unsloth-qwen3-4b
#
unsloth>=2025.10.14
unsloth_zoo
```
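Since `--no-deps` exists precisely to stop Unsloth from downgrading TRL, it is worth verifying the installed version afterwards. A stdlib-only sketch (the naive numeric parse is our simplification, not the `packaging` library):

```python
import importlib.metadata as md
import re

def version_tuple(version: str) -> tuple:
    # Naive numeric parse; adequate for trl-style versions like "0.29.0".
    return tuple(int(part) for part in re.findall(r"\d+", version)[:3])

def trl_new_enough(minimum: str = "0.29.0") -> bool:
    try:
        return version_tuple(md.version("trl")) >= version_tuple(minimum)
    except md.PackageNotFoundError:
        return False

print(version_tuple("0.29.0") >= version_tuple("0.24"))  # → True
```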
run_agent_unsloth.py
CHANGED

```diff
@@ -40,7 +40,7 @@ LOAD_IN_4BIT = (
     not in {"0", "false", "off"}
 )
 FAST_INFERENCE = (
-    os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "…
+    os.getenv("RUN_AGENT_UNSLOTH_FAST_INFERENCE", "0").strip().lower()
     not in {"0", "false", "off"}
 )
```

(The removed line is truncated in the extraction; the new default makes fast inference opt-in.)
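The surrounding pattern, parsing an environment variable into a boolean flag, can be sketched on its own:

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    # Truthy unless the value (trimmed, lowercased) is a conventional "off" string.
    return os.getenv(name, default).strip().lower() not in {"0", "false", "off"}

os.environ["DEMO_FAST_INFERENCE"] = " False "
print(env_flag("DEMO_FAST_INFERENCE"))  # → False
print(env_flag("DEMO_MISSING_FLAG"))   # → False (default "0")
```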
scripts/register_kernel_312.sh
ADDED (@@ -0,0 +1,17 @@)

```bash
#!/bin/bash
# Register a Python 3.12 Jupyter kernel from the project's uv venv.
# Run from project root: bash scripts/register_kernel_312.sh

set -e
cd "$(dirname "$0")/.."

echo "Ensuring Python 3.12 venv..."
uv python install 3.12
UV_TORCH_BACKEND="${UV_TORCH_BACKEND:-cu128}" uv sync --extra train

echo "Registering Jupyter kernel 'OpenEnv Bio (Python 3.12)'..."
uv run python -m ipykernel install --user \
    --name openenv-bio-312 \
    --display-name "OpenEnv Bio (Python 3.12)"

echo "Done. Select 'OpenEnv Bio (Python 3.12)' in the notebook kernel picker."
```
train.ipynb
CHANGED

The left-hand side of this diff is truncated in the extraction; the updated cells are fully recoverable. The markdown intro under "# Train A Self-Driving Lab Policy on H100" now reads:

> This notebook trains a GRPO policy for the **same bio-experiment planning task** as `run_agent.py`: choosing structured actions (collect_sample, run_qc, cluster, de_analysis, etc.) step-by-step in the OpenEnv bio-experiment environment.
>
> **Flow:** Build prompts from `BioExperimentEnvironment` rollouts (same env `run_agent.py` uses) → OpenEnv reward scores actions locally → GRPO trains the model. Uses `build_openenv_reward`, `prepare_prompt_examples`, and `build_grpo_trainer` from `training_script.py`.

The import cell reloads `training_script` so on-disk edits take effect without a kernel restart:

```python
from pathlib import Path
import importlib

import torch
import training_script as training_script_module

training_script_module = importlib.reload(training_script_module)
make_training_args = training_script_module.make_training_args
prepare_prompt_examples = training_script_module.prepare_prompt_examples
run_training = training_script_module.run_training

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    ...
```

The training config is retuned for H100:

```python
args = make_training_args(
    model_id="Qwen/Qwen3.5-9B",
    output_dir="artifacts/grpo-h100",
    dataset_episodes=64,            # more data per run
    rollout_steps=12,               # slightly longer trajectories
    collection_policy="heuristic",
    reward_backend="local",
    domain_randomise=True,

    num_generations=8,              # H100 can handle a larger GRPO group
    max_completion_length=192,      # small bump if completions are being cut off
    max_prompt_length=1024,         # trim a bit unless you truly need 1280

    per_device_train_batch_size=8,  # first thing to try on H100
    gradient_accumulation_steps=2,  # same effective batch as before, fewer sync steps
    learning_rate=1e-5,             # slightly more aggressive for LoRA/QLoRA-style RL tuning
    num_train_epochs=1.0,

    logging_steps=1,
    save_steps=25,
    trust_remote_code=True,
    ...
)
```

And the preview cell now builds a one-episode dataset through `prepare_prompt_examples`:

```python
# Same prompt format run_agent.py sees: SYSTEM_PROMPT + observation
preview_data = prepare_prompt_examples(
    make_training_args(
        dataset_episodes=1,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_name=["cardiac_disease_de"],
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
)
preview_examples = preview_data["examples"]

print(preview_examples[0]["prompt"][:3500])
print("\nReference action:\n", preview_examples[0]["reference_action"])
```
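The reload-and-rebind idiom used in the import cell (reload the module, then re-grab the attributes, since names bound earlier still point at the stale module object) works for any module; a minimal sketch with a stdlib module:

```python
import importlib
import json

# importlib.reload returns the re-executed module object; rebind names from it.
json = importlib.reload(json)
dumps = json.dumps

print(dumps({"a": 1}))  # → {"a": 1}
```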
train_unsloth.ipynb
ADDED (@@ -0,0 +1,149 @@)

New notebook; cell contents shown in jupytext percent format:

```python
# %% [markdown]
# # Train a Self-Driving Lab Policy with Unsloth
#
# This notebook uses **Unsloth** for fast quantized training on GPU nodes (e.g. H100).
# It mirrors `train.ipynb` but loads the model via Unsloth's optimized path with
# 4-bit quantization and LoRA adapters.
#
# **Model**: Uses **Qwen3-4B-Base** by default. Alternatives:
# - `Qwen/Qwen3-4B-Base` (base, no chat template)
# - `unsloth/Qwen2.5-3B-Instruct-bnb-4bit`
# - `unsloth/Qwen2.5-7B-Instruct-bnb-4bit`

# %%
# Install Unsloth and training dependencies (run once per session)
# Option A: uv (if using a uv-managed venv)
# !uv sync --extra train
# !uv pip install unsloth unsloth_zoo --no-deps

# Option B: pip
%pip install -q -U torch transformers datasets trl accelerate bitsandbytes unsloth unsloth_zoo matplotlib huggingface_hub

# Optional extras used by some reward-scoring paths.
%pip install -q -U sentence-transformers gseapy

# %%
# Unsloth must be imported before trl, transformers, peft
import unsloth  # noqa: F401

from pathlib import Path
import torch

from training_unsloth import make_training_args, run_training
import training_script as base

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("bf16 supported:", torch.cuda.is_bf16_supported())

Path("artifacts").mkdir(exist_ok=True)

# %%
args = make_training_args(
    model_id="Qwen/Qwen3-4B-Base",
    output_dir="artifacts/grpo-unsloth-qwen3-4b",
    dataset_episodes=32,
    rollout_steps=10,
    collection_policy="heuristic",
    reward_backend="local",
    domain_randomise=True,
    num_generations=4,
    max_completion_length=160,
    max_prompt_length=1280,
    max_seq_length=2048,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=1.0,
    logging_steps=1,
    save_steps=25,
    trust_remote_code=True,
    dry_run=False,
    seed=42,
)

args

# %%
preview_examples = base.build_prompt_examples(
    dataset_episodes=1,
    rollout_steps=args.rollout_steps,
    collection_policy=args.collection_policy,
    scenario_names=["cardiac_disease_de"],
    seed=args.seed,
    domain_randomise=args.domain_randomise,
)

print(preview_examples[0]["prompt"][:3500])
print("\nReference action:\n", preview_examples[0]["reference_action"])

# %%
# Optional smoke test before a full run.
dry_run_args = make_training_args(**{**vars(args), "dry_run": True})
dry_run_result = run_training(dry_run_args)
len(dry_run_result["examples"])

# %%
from IPython.display import Image, display

train_result = run_training(args)
for name, plot_path in train_result["plot_paths"].items():
    print(name, plot_path)
    display(Image(filename=plot_path))
```
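The dry-run cell clones the argparse `Namespace` with one field overridden via `make_training_args(**{**vars(args), "dry_run": True})`; the dict-merge pattern behind it looks like this:

```python
from argparse import Namespace

args = Namespace(model_id="Qwen/Qwen3-4B-Base", dry_run=False, seed=42)

# vars() exposes the Namespace as a dict; the key on the right wins the merge,
# and the original Namespace is left untouched.
dry_run_args = Namespace(**{**vars(args), "dry_run": True})

print(dry_run_args.dry_run, args.dry_run)  # → True False
```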
training_script.py
CHANGED (+228 −35)

Hunks touched: `@@ -1266,11 +1266,122 @@` (after `resolve_torch_runtime`), `@@ -1295,6 +1406,109 @@` (after `load_model_artifacts`), and `@@ -1363,20 +1577,10 @@` / `@@ -1389,41 +1593,30 @@` (inside `run_training`). The left-hand side survives the extraction only in part; what is recoverable:

- `run_training` no longer builds everything inline. The removed lines include the dataset construction (`collection_policy=args.collection_policy`, `scenario_names=scenario_names`, `seed=args.seed`, `domain_randomise=args.domain_randomise`), the inline `reward_fn = OpenEnvReward(reward_backend=args.reward_backend, base_url=args.base_url, domain_randomise=args.domain_randomise)`, the `from trl import GRPOConfig, GRPOTrainer` import, the full inline `GRPOConfig(...)` keyword list (`output_dir`, `learning_rate`, `per_device_train_batch_size`, `gradient_accumulation_steps`, `num_generations`, `max_completion_length`, `num_train_epochs`, `logging_steps`, `save_steps`, `bf16`, `fp16`, `report_to="none"`, `remove_unused_columns=False`), and the direct `GRPOTrainer(model=..., reward_funcs=reward_fn, args=config, train_dataset=train_dataset)` construction. The dry-run preview, `trainer.train()`, and `trainer.save_model(args.output_dir)` calls remain.
- `load_model_artifacts(model_id, *, trust_remote_code)` and `generate_action_with_model(model, tokenizer, ...)` keep their signatures.
- Two import guards, `_guard_invalid_torchao_version` and `_guard_partial_vllm_install`, are added after `resolve_torch_runtime`.
```python
def _guard_invalid_torchao_version() -> None:
    """Treat malformed torchao installs as unavailable for HF imports."""
    import functools
    import importlib.metadata as importlib_metadata
    import sys
    from packaging.version import InvalidVersion, Version

    if getattr(importlib_metadata, "_openenv_torchao_guard_installed", False):
        metadata_guard_installed = True
    else:
        original_version = importlib_metadata.version

        def guarded_version(distribution_name: str) -> str:
            version = original_version(distribution_name)
            if distribution_name.lower() == "torchao":
                try:
                    Version(version)
                except InvalidVersion as exc:
                    raise importlib_metadata.PackageNotFoundError(
                        f"Malformed torchao version metadata: {version!r}"
                    ) from exc
            return version

        importlib_metadata.version = guarded_version
        importlib_metadata._openenv_torchao_guard_installed = True
        metadata_guard_installed = False

    import_utils = sys.modules.get("transformers.utils.import_utils")
    if import_utils is not None and not getattr(import_utils, "_openenv_torchao_guard_installed", False):
        original_is_package_available = import_utils._is_package_available

        def guarded_is_package_available(
            pkg_name: str,
            return_version: bool = False,
        ):
            if pkg_name != "torchao":
                return original_is_package_available(pkg_name, return_version=return_version)
            is_available, package_version = original_is_package_available(
                pkg_name,
                return_version=True,
            )
            # Return a bare bool when return_version is False, matching the
            # contract of transformers' _is_package_available.
            if not is_available:
                return (False, package_version) if return_version else False
            try:
                Version(package_version)
            except InvalidVersion:
                return (False, "0") if return_version else False
            return (True, package_version) if return_version else True

        min_version = getattr(import_utils, "TORCHAO_MIN_VERSION", "0")

        @functools.lru_cache
        def guarded_is_torchao_available(min_version_override: str = min_version) -> bool:
            is_available, package_version = guarded_is_package_available(
                "torchao",
                return_version=True,
            )
            if not is_available:
                return False
            try:
                return Version(package_version) >= Version(min_version_override)
            except InvalidVersion:
                return False

        if hasattr(import_utils.is_torchao_available, "cache_clear"):
            import_utils.is_torchao_available.cache_clear()
        import_utils._is_package_available = guarded_is_package_available
        import_utils.is_torchao_available = guarded_is_torchao_available
        import_utils._openenv_torchao_guard_installed = True

        transformers_utils = sys.modules.get("transformers.utils")
        if transformers_utils is not None:
            transformers_utils.is_torchao_available = guarded_is_torchao_available

    if metadata_guard_installed and import_utils is None:
        return
```
+
def _guard_partial_vllm_install() -> None:
|
| 1348 |
+
"""Treat partial vLLM installs as unavailable for TRL imports."""
|
| 1349 |
+
import functools
|
| 1350 |
+
import importlib
|
| 1351 |
+
|
| 1352 |
+
try:
|
| 1353 |
+
import trl.import_utils as trl_import_utils
|
| 1354 |
+
except Exception:
|
| 1355 |
+
return
|
| 1356 |
+
|
| 1357 |
+
if getattr(trl_import_utils, "_openenv_vllm_guard_installed", False):
|
| 1358 |
+
return
|
| 1359 |
+
|
| 1360 |
+
def _has_usable_vllm() -> bool:
|
| 1361 |
+
try:
|
| 1362 |
+
importlib.import_module("vllm")
|
| 1363 |
+
importlib.import_module("vllm.distributed.device_communicators.pynccl")
|
| 1364 |
+
importlib.import_module("vllm.distributed.utils")
|
| 1365 |
+
except Exception:
|
| 1366 |
+
return False
|
| 1367 |
+
return True
|
| 1368 |
+
|
| 1369 |
+
@functools.lru_cache
|
| 1370 |
+
def guarded_is_vllm_available(*args: Any, **kwargs: Any) -> bool:
|
| 1371 |
+
return _has_usable_vllm()
|
| 1372 |
+
|
| 1373 |
+
if hasattr(trl_import_utils.is_vllm_available, "cache_clear"):
|
| 1374 |
+
trl_import_utils.is_vllm_available.cache_clear()
|
| 1375 |
+
trl_import_utils.is_vllm_available = guarded_is_vllm_available
|
| 1376 |
+
trl_import_utils._openenv_vllm_guard_installed = True
|
| 1377 |
+
|
| 1378 |
+
|
| 1379 |
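The metadata-patching idea behind `_guard_invalid_torchao_version` can be exercised in isolation. Below is a minimal sketch of the same guard, decoupled from `importlib.metadata` monkey-patching; the `examplepkg` name and the fake metadata table are hypothetical, invented for the demo:

```python
# Sketch: surface a malformed distribution version as PackageNotFoundError
# instead of letting a later Version() parse crash. The lookup table below
# simulates importlib.metadata so nothing real is patched.
import importlib.metadata as importlib_metadata
from packaging.version import InvalidVersion, Version


def make_guarded_version(original_version, guarded_name: str):
    """Wrap a version-lookup callable so one distribution's malformed
    version string is reported as the package being absent."""
    def guarded(distribution_name: str) -> str:
        version = original_version(distribution_name)
        if distribution_name.lower() == guarded_name:
            try:
                Version(version)  # validate PEP 440 syntax
            except InvalidVersion as exc:
                raise importlib_metadata.PackageNotFoundError(
                    f"Malformed {guarded_name} version metadata: {version!r}"
                ) from exc
        return version
    return guarded


# Simulated metadata: one healthy package, one with broken version metadata.
fake_db = {"examplepkg": "1.2.post-broken??", "goodpkg": "2.0.1"}
guarded = make_guarded_version(fake_db.__getitem__, "examplepkg")

print(guarded("goodpkg"))  # valid versions pass through unchanged
try:
    guarded("examplepkg")
except importlib_metadata.PackageNotFoundError as exc:
    print("guarded:", type(exc).__name__)
```

The real function additionally patches `transformers.utils.import_utils` in place; this sketch only shows the validation step.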
```diff
 def load_model_artifacts(
     model_id: str,
     *,
     trust_remote_code: bool,
 ):
+    _guard_invalid_torchao_version()
     from transformers import AutoModelForCausalLM, AutoTokenizer

     runtime = resolve_torch_runtime()
```

```diff
     return tokenizer, model


+def build_openenv_reward(args: argparse.Namespace) -> OpenEnvReward:
+    """Return the OpenEnv-compatible reward callable used by GRPO."""
+    return OpenEnvReward(
+        reward_backend=args.reward_backend,
+        base_url=args.base_url,
+        domain_randomise=args.domain_randomise,
+    )
+
+
+def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
+    """Build the OpenEnv rollout states that seed GRPO prompts."""
+    scenario_names = selected_scenarios(args.scenario_name)
+    examples = build_prompt_examples(
+        dataset_episodes=args.dataset_episodes,
+        rollout_steps=args.rollout_steps,
+        collection_policy=args.collection_policy,
+        scenario_names=scenario_names,
+        seed=args.seed,
+        domain_randomise=args.domain_randomise,
+    )
+    return {
+        "scenario_names": scenario_names,
+        "examples": examples,
+    }
+
+
+def build_grpo_config(
+    args: argparse.Namespace,
+    runtime: Dict[str, Any],
+):
+    import inspect
+
+    _guard_invalid_torchao_version()
+    _guard_partial_vllm_install()
+    from trl import GRPOConfig
+
+    config_kwargs = {
+        "output_dir": args.output_dir,
+        "learning_rate": args.learning_rate,
+        "per_device_train_batch_size": args.per_device_train_batch_size,
+        "gradient_accumulation_steps": args.gradient_accumulation_steps,
+        "num_generations": args.num_generations,
+        "max_completion_length": args.max_completion_length,
+        "max_prompt_length": args.max_prompt_length,
+        "num_train_epochs": args.num_train_epochs,
+        "logging_steps": args.logging_steps,
+        "save_steps": args.save_steps,
+        "bf16": runtime["bf16"],
+        "fp16": runtime["fp16"],
+        "report_to": "none",
+        "remove_unused_columns": False,
+    }
+    supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
+
+    # Older TRL builds may expose a single max_length knob instead of
+    # separate prompt/completion limits.
+    if (
+        "max_length" in supported_params
+        and "max_prompt_length" not in supported_params
+        and "max_completion_length" not in supported_params
+    ):
+        config_kwargs["max_length"] = (
+            args.max_prompt_length + args.max_completion_length
+        )
+
+    filtered_kwargs = {
+        key: value
+        for key, value in config_kwargs.items()
+        if key in supported_params
+    }
+    skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
+    if skipped:
+        print(
+            "GRPOConfig compatibility: skipping unsupported fields "
+            f"{', '.join(skipped)}"
+        )
+
+    return GRPOConfig(**filtered_kwargs)
```
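The compatibility pattern used by `build_grpo_config` — inspect the target's `__init__` signature and silently drop kwargs the installed version does not accept — works for any config class. A standalone sketch; `OldConfig` is a made-up stand-in, not a real TRL class:

```python
# Sketch: construct a config object across library versions by filtering
# kwargs against the constructor's actual signature.
import inspect


class OldConfig:
    # Stand-in for an older library config that lacks newer fields.
    def __init__(self, output_dir, learning_rate=5e-6):
        self.output_dir = output_dir
        self.learning_rate = learning_rate


def filtered_init(cls, **config_kwargs):
    """Instantiate cls with only the kwargs its __init__ accepts;
    return the instance and the sorted list of dropped keys."""
    supported = set(inspect.signature(cls.__init__).parameters)
    accepted = {k: v for k, v in config_kwargs.items() if k in supported}
    skipped = sorted(set(config_kwargs) - set(accepted))
    return cls(**accepted), skipped


cfg, skipped = filtered_init(
    OldConfig,
    output_dir="out",
    learning_rate=1e-5,
    num_generations=4,  # not a parameter of OldConfig: dropped, not a TypeError
)
print(skipped)  # ['num_generations']
```

The trade-off is that typos in kwarg names are also silently dropped, which is why the repo code prints the skipped fields instead of discarding them quietly.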
```diff
+def build_grpo_trainer(
+    *,
+    model: Any,
+    tokenizer: Any,
+    reward_func: Any,
+    train_dataset: Any,
+    args: argparse.Namespace,
+    runtime: Dict[str, Any],
+):
+    _guard_invalid_torchao_version()
+    _guard_partial_vllm_install()
+    from trl import GRPOTrainer
+
+    config = build_grpo_config(args, runtime)
+    return GRPOTrainer(
+        model=model,
+        reward_funcs=reward_func,
+        args=config,
+        train_dataset=train_dataset,
+        processing_class=tokenizer,
+    )
+
+
 def generate_action_with_model(
     model: Any,
     tokenizer: Any,
```

```diff
         "model": model,
     }

+    prompt_data = prepare_prompt_examples(args)
+    scenario_names = prompt_data["scenario_names"]
+    examples = prompt_data["examples"]
+    reward_fn = build_openenv_reward(args)

     if args.dry_run:
         run_dry_run_preview(examples, reward_fn, args.output_dir)
```

```diff
     }

     from datasets import Dataset

     train_dataset = Dataset.from_list(examples)
     tokenizer, model = load_model_artifacts(
         args.model_id,
         trust_remote_code=args.trust_remote_code,
     )

     print(
         f"Training runtime: device={runtime['device']} "
         f"name={runtime['device_name']} "
         f"dtype={runtime['dtype']}"
     )
+    print(
+        "OpenEnv reward: "
+        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
+        f"examples={len(examples)}"
+    )

+    trainer = build_grpo_trainer(
         model=model,
         train_dataset=train_dataset,
+        tokenizer=tokenizer,
+        reward_func=reward_fn,
+        args=args,
+        runtime=runtime,
     )
     trainer.train()
     trainer.save_model(args.output_dir)
```
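The vLLM guard above relies on a cached availability probe: import every required submodule once, then memoize the verdict with `lru_cache` so repeated checks cost nothing. The pattern generalizes; a standalone sketch (the module names below are illustrative, not the real vLLM ones):

```python
# Sketch: cached availability probe for a package whose install may be
# partial. A single missing submodule marks the whole package unusable.
import functools
import importlib


def make_availability_probe(module_names):
    @functools.lru_cache(maxsize=None)
    def is_available() -> bool:
        try:
            for name in module_names:
                importlib.import_module(name)
        except Exception:
            # Any import failure (missing module, broken C extension, ...)
            # counts as "not available".
            return False
        return True
    return is_available


# Fails: one required module does not exist.
partial = make_availability_probe(("json", "definitely_missing_submodule_xyz"))
print(partial())  # False

# Succeeds: every listed module imports cleanly.
complete = make_availability_probe(("json", "math"))
print(complete())  # True
```

Caching is also why the repo code calls `cache_clear()` before swapping the probe in: a stale memoized `True` from the original function would otherwise survive the patch.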
training_unsloth.py
CHANGED

```diff
@@ -6,6 +6,9 @@ but arranges the Unsloth path in the more typical pattern:
 2. load a quantized model
 3. apply LoRA adapters
 4. train with an explicit OpenEnv reward function
+
+NOTE: Unsloth must be imported before trl, transformers, peft. Import this
+module before training_script.
 """

 from __future__ import annotations
@@ -15,6 +18,9 @@ import random
 from pathlib import Path
 from typing import Any, Dict, Optional, Sequence

+# Unsloth must be imported before trl/transformers/peft for optimizations.
+import unsloth  # noqa: F401
+
 import training_script as base

 DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
@@ -36,11 +42,21 @@ LORA_TARGET_MODULES = [
 def require_unsloth():
     try:
         from unsloth import FastLanguageModel, PatchFastRL
-    except ImportError as exc:
+    except ImportError as exc:
+        msg = str(exc)
+        if "vllm.lora" in msg or "vllm" in msg.lower():
+            raise RuntimeError(
+                f"Unsloth failed: {exc}. "
+                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
+                "  pip install 'vllm==0.8.2'  # requires torch 2.6\n"
+                "  pip install 'vllm==0.7.3'  # alternative\n"
+                "If torch>=2.10 conflicts, use a separate env with torch 2.6-2.8."
+            ) from exc
+        if "unsloth" in msg.lower():
+            raise RuntimeError(
+                "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
+            ) from exc
+        raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
     return FastLanguageModel, PatchFastRL


@@ -74,11 +90,6 @@ def build_argument_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Disable 4-bit quantized loading and use the wider base weights.",
     )
-    parser.add_argument(
-        "--disable-fast-inference",
-        action="store_true",
-        help="Disable Unsloth fast inference kernels where supported.",
-    )
     parser.add_argument(
         "--lora-r",
         type=int,
@@ -125,7 +136,7 @@ def load_model_artifacts(
     trust_remote_code: bool,
     max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
     load_in_4bit: bool = True,
-    fast_inference: bool =
+    fast_inference: bool = False,
     prepare_for_inference: bool = False,
 ):
     FastLanguageModel, _ = require_unsloth()
@@ -136,7 +147,7 @@ def load_model_artifacts(
         FastLanguageModel,
         model_name=model_id,
         max_seq_length=max_seq_length,
-        dtype="
+        dtype=runtime["dtype"],
         load_in_4bit=load_in_4bit,
         fast_inference=fast_inference,
         trust_remote_code=trust_remote_code,
@@ -211,23 +222,43 @@ def build_grpo_config(
     args: argparse.Namespace,
     runtime: Dict[str, Any],
 ):
+    import inspect
+
+    base._guard_invalid_torchao_version()
+    base._guard_partial_vllm_install()
     from trl import GRPOConfig

+    supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
+    config_kwargs = {
+        "output_dir": args.output_dir,
+        "learning_rate": args.learning_rate,
+        "per_device_train_batch_size": args.per_device_train_batch_size,
+        "gradient_accumulation_steps": args.gradient_accumulation_steps,
+        "num_generations": args.num_generations,
+        "max_completion_length": args.max_completion_length,
+        "num_train_epochs": args.num_train_epochs,
+        "logging_steps": args.logging_steps,
+        "save_steps": args.save_steps,
+        "bf16": runtime["bf16"],
+        "fp16": runtime["fp16"],
+        "report_to": "none",
+        "remove_unused_columns": False,
+    }
+    # Only add max_prompt_length if this TRL version supports it; UnslothGRPOTrainer can
+    # fail when passing it to parent, so we only pass when explicitly supported.
+    if "max_prompt_length" in supported_params:
+        config_kwargs["max_prompt_length"] = None  # text-only; avoids image_token_id crash
+    if (
+        "max_length" in supported_params
+        and "max_prompt_length" not in supported_params
+        and "max_completion_length" not in supported_params
+    ):
+        config_kwargs["max_length"] = getattr(args, "max_prompt_length", 1024) + args.max_completion_length
+    filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}
+    skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
+    if skipped:
+        print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(skipped)}")
+    return GRPOConfig(**filtered_kwargs)


 def build_unsloth_grpo_trainer(
@@ -239,6 +270,8 @@ def build_unsloth_grpo_trainer(
     args: argparse.Namespace,
     runtime: Dict[str, Any],
 ):
+    base._guard_invalid_torchao_version()
+    base._guard_partial_vllm_install()
     from trl import GRPOTrainer

     config = build_grpo_config(args, runtime)
@@ -311,7 +344,7 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         trust_remote_code=args.trust_remote_code,
         max_seq_length=args.max_seq_length,
         load_in_4bit=not args.disable_4bit,
-        fast_inference=
+        fast_inference=False,
         prepare_for_inference=True,
     )
     device = getattr(model, "device", "unknown")
@@ -352,7 +385,7 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         trust_remote_code=args.trust_remote_code,
         max_seq_length=args.max_seq_length,
         load_in_4bit=not args.disable_4bit,
-        fast_inference=False,
     )
     # 2. Apply LoRA adapters.
     model = apply_lora_adapters(FastLanguageModel, model, args)
@@ -378,6 +411,11 @@ def run_training(args: argparse.Namespace) -> Dict[str, Any]:
         args=args,
         runtime=runtime,
     )
+    # Workaround: UnslothGRPOTrainer expects vision token IDs for max_prompt_length
+    # truncation; text-only models don't have them. Set to None so protected=[].
+    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
+        if not hasattr(trainer, attr):
+            setattr(trainer, attr, None)
     trainer.train()
     trainer.save_model(args.output_dir)
     tokenizer.save_pretrained(args.output_dir)
```
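The trainer-attribute workaround at the end of `run_training` is plain defensive `setattr` defaulting, and can be checked without Unsloth or TRL installed. A sketch; `DummyTrainer` and the token ID are invented for the demo:

```python
# Sketch: default missing vision-token attributes to None so downstream
# prompt-truncation code computes an empty protected-token list for
# text-only models, while VLM trainers keep their real IDs.
class DummyTrainer:
    image_token_id = 151655  # present on a hypothetical VLM trainer


trainer = DummyTrainer()
for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
    if not hasattr(trainer, attr):
        setattr(trainer, attr, None)  # only fills genuinely missing attrs

# Downstream code can now read all three attributes unconditionally.
protected = [
    tid
    for tid in (
        trainer.image_token_id,
        trainer.vision_start_token_id,
        trainer.vision_end_token_id,
    )
    if tid is not None
]
print(protected)  # [151655]
```

Because `hasattr` gates the assignment, a trainer that already defines any of the three attributes is left untouched.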
unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py
ADDED — File without changes

unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py
ADDED — File without changes
unsloth_compiled_cache/UnslothCPOTrainer.py
ADDED

@@ -0,0 +1,1936 @@
"""
2026.3.2
2026.3.4
5.3.0
0.24.0
__UNSLOTH_VERSIONING__
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)


import os
import math
import logging
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
import inspect
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
from transformers.training_args import ParallelMode
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize

# Wrap trainer with padding to right and enable training mode
# Also patches W&B since multiple runs must use wandb.finish()
import functools
from types import MethodType
try:
    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
except:
    def reset_unsloth_gradient_checkpointing_buffers(): pass
def prepare_for_training_mode(f):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # Enable training mode
        _was_training = None
        # Get gradient checkpointing setting from training arguments
        use_gc = getattr(self.args, 'gradient_checkpointing', True)
        if hasattr(self, 'model') and hasattr(self.model, "training"):
            _was_training = self.model.training
        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
            self.model.for_training(use_gradient_checkpointing=use_gc)
        output = f(self, *args, **kwargs)
        # Restore previous mode when possible
        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
            if _was_training is False:
                self.model.for_inference()
            elif _was_training is True and hasattr(self.model, "for_training"):
                self.model.for_training(use_gradient_checkpointing=use_gc)
        # Reset gradient checkpointing buffers to free memory while staying ready for next run
        try:
            reset_unsloth_gradient_checkpointing_buffers()
        except:
            pass
        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
        try:
            import wandb
            wandb.finish()
        except:
            pass
        return output
    return wrapper
pass

torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    # All Unsloth Zoo code licensed under AGPL3
    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    flat_index = index.reshape(-1)

    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)

    all_per_token_logps = []

    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()

        if logit_scale_multiply != 0.0:
            chunk_logits = chunk_logits * logit_scale_multiply
        if logit_scale_divide != 0.0:
            chunk_logits = chunk_logits / logit_scale_divide
        if logit_softcapping != 0.0:
            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)

        chunk_logits = chunk_logits.to(torch.float32)

        if temperature != 1.0:
            chunk_logits = chunk_logits / temperature

        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)

    all_per_token_logps = torch.concat(all_per_token_logps)

    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
    return all_per_token_logps

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    # Split into 4 chunks only
    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
    all_per_token_logps = []
    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
        chunk_logits = chunk_logits.to(torch.float32)
        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)
    pass
    all_per_token_logps = torch.concat(all_per_token_logps)
    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
    return all_per_token_logps

def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a prompt tensor, returns the number of left-pad tokens in each sequence, so [pad, pad, pad, cat] = 3 tokens.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    prompt_section = input_ids[:, :-logits_to_keep]

    padding_mask = (prompt_section == pad_token_id)

    pad_token_counts = padding_mask.sum(dim=1)

    return pad_token_counts

def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a sequence [p,p,p,c,c,c,pad,pad,pad],

    where p are extra prompt tokens left over from slicing the torch tensor, c are completion tokens
    and pad are pad tokens, this function builds a completion mask that zeroes out the pad
    and p tokens. So in this example: [0,0,0,1,1,1,0,0,0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device=device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask

    return final_mask

def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    # Must do stable=True since the binary mask is unordered
    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor

def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Aligns a log probability tensor with a given attention mask.
    """

    device = logprob_tensor.device
    batch_size, logprob_seq_len = logprob_tensor.shape
    mask_seq_len = attention_mask.shape[1]

    padded_logprobs = torch.full(
        attention_mask.shape,
        fill_value=pad_value,
        dtype=logprob_tensor.dtype,
        device=device
    )

    left_pad_counts = torch.argmax(attention_mask, dim=1)

    cols = torch.arange(logprob_seq_len, device=device)
    dest_indices = left_pad_counts.unsqueeze(1) + cols

    # Create destination row indices
    # Shape: [batch_size, logprob_seq_len]
    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)

    # --- 4. Filter out-of-bounds indices and perform assignment ---
    # Create a mask to identify only the indices that are within the bounds
    # of the target tensor's sequence length.
    valid_mask = dest_indices < mask_seq_len

    # Use this mask to select only the valid row indices, column indices,
    # and the corresponding values from the logprob tensor.
    # This flattens the selected elements into 1D tensors.
    valid_rows = row_indices[valid_mask]
    valid_cols = dest_indices[valid_mask]
    valid_vals = logprob_tensor[valid_mask]

    # Place the valid values into their correct positions in the padded tensor
    # using a single, efficient advanced indexing operation.
    padded_logprobs[valid_rows, valid_cols] = valid_vals

    return padded_logprobs
+
|
| 264 |
+
def autotune_batch_and_chunks(
|
| 265 |
+
total_input_rows,
|
| 266 |
+
seq_len,
|
| 267 |
+
hidden_size,
|
| 268 |
+
vocab_size,
|
| 269 |
+
dtype_bytes=16,
|
| 270 |
+
multiplier=None
|
| 271 |
+
):
|
| 272 |
+
if multiplier is None:
|
| 273 |
+
final_m = max(4, seq_len // 4096)
|
| 274 |
+
else:
|
| 275 |
+
final_m = multiplier
|
| 276 |
+
|
| 277 |
+
if torch.cuda.is_available():
|
| 278 |
+
free_bytes, _ = torch.cuda.mem_get_info()
|
| 279 |
+
limit_gb = (free_bytes / (1024**3))*.80
|
| 280 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 281 |
+
# For XPU: estimate free memory from total - reserved
|
| 282 |
+
total_mem = torch.xpu.get_device_properties(0).total_memory
|
| 283 |
+
reserved_mem = torch.xpu.memory_reserved()
|
| 284 |
+
free_bytes = total_mem - reserved_mem
|
| 285 |
+
limit_gb = (free_bytes / (1024**3)) * 0.80
|
| 286 |
+
else:
|
| 287 |
+
# Fallback: assume 8GB available
|
| 288 |
+
limit_gb = 8.0
|
| 289 |
+
|
| 290 |
+
bytes_to_gb = 1024**3
|
| 291 |
+
|
| 292 |
+
b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
|
| 295 |
+
|
| 296 |
+
base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
|
| 297 |
+
logits_gb = base_logits / final_m
|
| 298 |
+
|
| 299 |
+
total_mem_gb = hidden_gb + logits_gb
|
| 300 |
+
|
| 301 |
+
valid_mask = total_mem_gb <= limit_gb
|
| 302 |
+
valid_indices = torch.nonzero(valid_mask, as_tuple=False)
|
| 303 |
+
|
| 304 |
+
if valid_indices.shape[0] == 0:
|
| 305 |
+
#This means your GPU will OOM
|
| 306 |
+
return 4, final_m
|
| 307 |
+
|
| 308 |
+
best_idx = valid_indices[0].item()
|
| 309 |
+
final_b = int(b_vals[best_idx].item())
|
| 310 |
+
|
| 311 |
+
return final_b, final_m
|
| 312 |
+
|
| 313 |
+
def sanitize_logprob(logprob):
|
| 314 |
+
"""Local port of trl.scripts.vllm_serve.sanitize_logprob.
|
| 315 |
+
Filters NaN logprobs from vLLM outputs."""
|
| 316 |
+
value = logprob.logprob
|
| 317 |
+
if math.isnan(value):
|
| 318 |
+
logging.getLogger(__name__).warning(
|
| 319 |
+
f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
|
| 320 |
+
)
|
| 321 |
+
return None
|
| 322 |
+
return value
|
@dataclass
class UnslothCPOConfig(CPOConfig):
    """

    Configuration class for the [`CPOTrainer`].

    This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int`, *optional*):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
            the [paper](https://huggingface.co/papers/2310.12036).
        label_smoothing (`float`, *optional*, defaults to `0.0`):
            Label smoothing factor. This argument is required if you want to use the default data collator.
        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
            Type of loss to use. Possible values are:

            - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
            - `"hinge"`: hinge loss on the normalized likelihood from the
              [SLiC](https://huggingface.co/papers/2305.10425) paper.
            - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
            - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
            - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
              automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.

        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        cpo_alpha (`float`, *optional*, defaults to `1.0`):
            Weight of the BC regularizer in CPO training.
        simpo_gamma (`float`, *optional*, defaults to `0.5`):
            Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
        alpha (`float`, *optional*, defaults to `0.0`):
            Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
            standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
            / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
            loss types.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, *optional*):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
        is_encoder_decoder (`bool`, *optional*):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    unsloth_logit_chunk_multiplier : Optional[int] = field(
        default = None,
        metadata = {'help': 'Multiplier for chunked logit computations.'},
    )
    unsloth_grpo_mini_batch : Optional[int] = field(
        default = None,
        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        per_device_train_batch_size = 4,
        num_train_epochs = 3.0,
        max_steps = -1,
        learning_rate = 5e-05,
        lr_scheduler_type = 'linear',
        lr_scheduler_kwargs = None,
        warmup_steps = 0.1,
        optim = 'adamw_8bit',
        optim_args = None,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        optim_target_modules = None,
        gradient_accumulation_steps = 2,
        average_tokens_across_devices = True,
        max_grad_norm = 1.0,
        label_smoothing_factor = 0.0,
        bf16 = False,
        fp16 = False,
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = None,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        use_liger_kernel = False,
        liger_kernel_config = None,
        use_cache = False,
        neftune_noise_alpha = None,
        torch_empty_cache_steps = 250,
        auto_find_batch_size = False,
        logging_strategy = 'steps',
        logging_steps = 1,
        logging_first_step = False,
        log_on_each_node = True,
        logging_nan_inf_filter = False,
        include_num_input_tokens_seen = False,
        log_level = 'passive',
        log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        beta = 0.1,
        label_smoothing = 0.0,
        loss_type = 'sigmoid',
        disable_dropout = True,
        cpo_alpha = 1.0,
        simpo_gamma = 0.5,
        alpha = 0.0,
        label_pad_token_id = -100,
        padding_value = None,
        truncation_mode = 'keep_end',
        generate_during_eval = False,
        is_encoder_decoder = None,
        model_init_kwargs = None,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if num_train_epochs is None:
            num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        import multiprocessing as _mp
        if _mp.get_start_method() != 'fork':
            dataset_num_proc = None
        elif dataset_num_proc is None:
            import psutil
            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
            memory_gb_left = psutil.virtual_memory().available / (1024**3)
            if memory_gb_left <= 2: dataset_num_proc = 1
            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))

| 564 |
+
super().__init__(
|
| 565 |
+
output_dir = output_dir,
|
| 566 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 567 |
+
num_train_epochs = num_train_epochs,
|
| 568 |
+
max_steps = max_steps,
|
| 569 |
+
learning_rate = learning_rate,
|
| 570 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 571 |
+
lr_scheduler_kwargs = lr_scheduler_kwargs,
|
| 572 |
+
warmup_steps = warmup_steps,
|
| 573 |
+
optim = optim,
|
| 574 |
+
optim_args = optim_args,
|
| 575 |
+
weight_decay = weight_decay,
|
| 576 |
+
adam_beta1 = adam_beta1,
|
| 577 |
+
adam_beta2 = adam_beta2,
|
| 578 |
+
adam_epsilon = adam_epsilon,
|
| 579 |
+
optim_target_modules = optim_target_modules,
|
| 580 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 581 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 582 |
+
max_grad_norm = max_grad_norm,
|
| 583 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 584 |
+
bf16 = bf16,
|
| 585 |
+
fp16 = fp16,
|
| 586 |
+
bf16_full_eval = bf16_full_eval,
|
| 587 |
+
fp16_full_eval = fp16_full_eval,
|
| 588 |
+
tf32 = tf32,
|
| 589 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 590 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 591 |
+
torch_compile = torch_compile,
|
| 592 |
+
torch_compile_backend = torch_compile_backend,
|
| 593 |
+
torch_compile_mode = torch_compile_mode,
|
| 594 |
+
use_liger_kernel = use_liger_kernel,
|
| 595 |
+
liger_kernel_config = liger_kernel_config,
|
| 596 |
+
use_cache = use_cache,
|
| 597 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 598 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 599 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 600 |
+
logging_strategy = logging_strategy,
|
| 601 |
+
logging_steps = logging_steps,
|
| 602 |
+
logging_first_step = logging_first_step,
|
| 603 |
+
log_on_each_node = log_on_each_node,
|
| 604 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 605 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 606 |
+
log_level = log_level,
|
| 607 |
+
log_level_replica = log_level_replica,
|
| 608 |
+
disable_tqdm = disable_tqdm,
|
| 609 |
+
report_to = report_to,
|
| 610 |
+
run_name = run_name,
|
| 611 |
+
project = project,
|
| 612 |
+
trackio_space_id = trackio_space_id,
|
| 613 |
+
eval_strategy = eval_strategy,
|
| 614 |
+
eval_steps = eval_steps,
|
| 615 |
+
eval_delay = eval_delay,
|
| 616 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 617 |
+
prediction_loss_only = prediction_loss_only,
|
| 618 |
+
eval_on_start = eval_on_start,
|
| 619 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 620 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 621 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 622 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 623 |
+
save_only_model = save_only_model,
|
| 624 |
+
save_strategy = save_strategy,
|
| 625 |
+
save_steps = save_steps,
|
| 626 |
+
save_on_each_node = save_on_each_node,
|
| 627 |
+
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            beta = beta,
            label_smoothing = label_smoothing,
            loss_type = loss_type,
            disable_dropout = disable_dropout,
            cpo_alpha = cpo_alpha,
            simpo_gamma = simpo_gamma,
            alpha = alpha,
            label_pad_token_id = label_pad_token_id,
            padding_value = padding_value,
            truncation_mode = truncation_mode,
            generate_during_eval = generate_during_eval,
            is_encoder_decoder = is_encoder_decoder,
            model_init_kwargs = model_init_kwargs,
            dataset_num_proc = dataset_num_proc,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    "which is `self.per_device_train_batch_size * gradient_accumulation_steps`."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length

pass

class _UnslothCPOTrainer(BaseTrainer):
    r""""""

    _tag_names = ["trl", "cpo"]
    _name = "CPO"
    _paper = {
        "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
        "id": "2401.08417",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @inproceedings{xu2024contrastive,
                title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
                author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
                year = 2024,
                booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
                publisher = {OpenReview.net},
                url = {https://openreview.net/forum?id=51iwkioZpn}
            }"""),
    }

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[CPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    ):
        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
            warnings.warn(
                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
                "it and want it to remain, please share your comments here: "
                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
                "TRL_EXPERIMENTAL_SILENCE=1."
            )
        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            dtype = model_init_kwargs.get("dtype")
            if dtype is not None:
                # Convert to `torch.dtype` if a str is passed
                if isinstance(dtype, str) and dtype != "auto":
                    dtype = getattr(torch, dtype)
                if dtype != "auto" and not isinstance(dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
                    )
                model_init_kwargs["dtype"] = dtype

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif args.gradient_checkpointing:
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = model
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif args.gradient_checkpointing:
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        if self.is_encoder_decoder:
            self.decoder_start_token_id = model.config.decoder_start_token_id
            self.pad_token_id = model.config.pad_token_id

        if processing_class is None:
            raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
        if args.max_length is None:
            logger.warning(
                "`max_length` is not set in the CPOConfig's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
            )
            max_length = 512
        else:
            max_length = args.max_length
        if args.max_prompt_length is None:
            logger.warning(
                "`max_prompt_length` is not set in the CPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
            )
            max_prompt_length = 128
        else:
            max_prompt_length = args.max_prompt_length

        if not max_prompt_length < max_length:
            raise ValueError(
                f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})."
            )

        if args.max_completion_length is None and self.is_encoder_decoder:
            logger.warning(
                "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
            )
            max_completion_length = 128
        else:
            max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            logger.warning(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                " we have set it for you, but you should do it yourself in the future.",
            )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.max_completion_length = max_completion_length
        self.processing_class = processing_class

        if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
            logger.warning(
                f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
                "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
            )
        if args.loss_type == "kto_pair":
            raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

        self.beta = args.beta
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.cpo_alpha = args.cpo_alpha
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            logger.warning(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
            )

        if args.loss_type == "simpo":
            self.simpo_gamma = args.simpo_gamma

        # AlphaPO parameter for reward shaping
        self.alpha = args.alpha

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
        # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().main_process_first():
            # Extract the prompt if needed, and apply the chat template if needed
            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            train_dataset = train_dataset.map(
                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                )

            # tokenize the dataset
            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

    def build_tokenized_answer(self, prompt, answer):
        """
        The Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) +
        enc(a + b)[len(enc(a)):]`. Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """

        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # On some tokenizers, like the Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # in the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If the tokenized prompt differs from the prompt part of prompt+answer, then the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )

    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
        """Tokenize a single row from a CPO specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
        chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
        we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
        of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        if not self.is_encoder_decoder:
            # Check issues below for more details
            # 1. https://github.com/huggingface/trl/issues/907
            # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            # 3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure the prompts differ by one token at most,
            # and that their lengths differ by 1 at most
            num_diff_tokens = sum(
                a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt. Avoid adding if it's already there
            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
                self.processing_class.bos_token_id,
                prompt_len_input_ids,
                prompt_tokens,
                chosen_prompt_len_input_ids,
                chosen_tokens,
                rejected_prompt_len_input_ids,
                rejected_tokens,
            )

            # add EOS token to end of answer. Avoid adding if it's already there
            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
                self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
            )

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            chosen_tokens = self.processing_class(
                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            rejected_tokens = self.processing_class(
                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
            )
            prompt_tokens = self.processing_class(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["rejected_labels"])
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch["chosen_labels"])
                )

        return batch

    @staticmethod
    def concatenated_inputs(
        batch: dict[str, Union[list, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            batch:
                A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
                of shape (batch_size, sequence_length).
            is_encoder_decoder:
                Whether the model is an encoder-decoder model.
            label_pad_token_id:
                The label pad token id.
            padding_value:
                The padding value to use for the concatenated inputs_ids.
            device:
                The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
        else:
            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        return concatenated_batch

    def cpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the CPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps:
                Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps:
                Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO
            loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
            the chosen and rejected responses, respectively.
        """
        # Apply AlphaPO reward transformation if alpha != 0
        if self.alpha != 0.0:
            # Compute probabilities
            chosen_probs = torch.exp(policy_chosen_logps)
            rejected_probs = torch.exp(policy_rejected_logps)

            # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
            policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
            policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha

            logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
        else:
            # Standard log probability rewards when alpha = 0
            logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)

        # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the
        # labels and calculates a conservative CPO loss.

        if self.loss_type == "simpo":
            gamma_logratios = self.simpo_gamma / self.beta
            logits = logits - gamma_logratios
            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "sigmoid":
            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "ipo":
            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
|
| 1318 |
+
)
|
| 1319 |
+
|
| 1320 |
+
# Calculate rewards for logging
|
| 1321 |
+
if self.alpha != 0.0:
|
| 1322 |
+
# When using AlphaPO transformation, use the transformed rewards
|
| 1323 |
+
chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
|
| 1324 |
+
rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
|
| 1325 |
+
else:
|
| 1326 |
+
# Standard log probability rewards
|
| 1327 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
| 1328 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
| 1329 |
+
|
| 1330 |
+
return losses, chosen_rewards, rejected_rewards
|
| 1331 |
+
|
| 1332 |
+
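As an aside, the "sigmoid" branch of `cpo_loss` above can be exercised on toy tensors. This is a minimal illustrative sketch, not part of the generated file; `beta` and `label_smoothing` stand in for the trainer attributes of the same names.

```python
import torch
import torch.nn.functional as F

beta = 0.1
label_smoothing = 0.0

# Per-sequence log probabilities for two preference pairs
policy_chosen_logps = torch.tensor([-10.0, -5.0])
policy_rejected_logps = torch.tensor([-12.0, -4.0])

# Log-odds margin per pair (the alpha == 0 path above)
logits = policy_chosen_logps - policy_rejected_logps

# Sigmoid CPO loss with optional label smoothing
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
chosen_rewards = beta * policy_chosen_logps.detach()
rejected_rewards = beta * policy_rejected_logps.detach()
```

The first pair (chosen ranked above rejected) yields a lower loss than the second, where the ranking is inverted.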
    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels:
                Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
                ignored. Shape: (batch_size, sequence_length)
            average_log_prob:
                If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
                log probabilities of the (non-masked) tokens.
            label_pad_token_id: The label pad token id.
            is_encoder_decoder: Whether the model is an encoder-decoder model.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
            given logits.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0

        per_token_logps = selective_log_softmax(logits, labels)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

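The mask-and-gather logic inside `get_batch_logps` can be sketched on a tiny vocabulary. This is an illustrative sketch, not part of the generated file; it skips the decoder-only label shift and uses plain `gather` in place of `selective_log_softmax`.

```python
import torch

label_pad_token_id = -100

# (1, 3, 3) logits that are already log-probabilities (rows sum to 1 in prob space)
logits = torch.log(torch.tensor([[[0.7, 0.2, 0.1],
                                  [0.1, 0.8, 0.1],
                                  [0.3, 0.3, 0.4]]]))
labels = torch.tensor([[0, 1, -100]])

# Mask out padded label positions, then substitute a dummy token id
loss_mask = labels != label_pad_token_id
labels = labels.masked_fill(labels == label_pad_token_id, 0)

# Per-token log probability of each label token
log_probs = torch.log_softmax(logits, dim=-1)
per_token_logps = torch.gather(log_probs, 2, labels.unsqueeze(2)).squeeze(2)

# Summed log probability per sequence, ignoring masked positions
seq_logps = (per_token_logps * loss_mask).sum(-1)
```

Here the sequence log probability is log(0.7) + log(0.8) = log(0.56); the padded third position contributes nothing.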
    def concatenated_forward(
        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
            }
            if self.is_encoder_decoder
            else {}
        )

        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        labels = concatenated_batch["concatenated_labels"].clone()

        if self.cpo_alpha == 0:
            nll_loss = torch.tensor(0.0).to(self.accelerator.device)
        else:
            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=self.loss_type in ["ipo", "simpo"],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

    def get_batch_loss_metrics(
        self,
        model,
        batch: dict[str, Union[list, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        forward_output = self.concatenated_forward(model, batch)
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = forward_output[:5]
        if self.aux_loss_enabled:
            aux_loss = forward_output[5]

        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
        )

        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
        metrics[f"{prefix}rewards/margins"] = (
            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
        )
        metrics[f"{prefix}logps/rejected"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
        )
        metrics[f"{prefix}logps/chosen"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
        )
        metrics[f"{prefix}logits/rejected"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item()
        )
        metrics[f"{prefix}logits/chosen"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item()
        )
        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()

        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss

        return loss, metrics

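The reward accuracy and margin metrics logged by `get_batch_loss_metrics` reduce to simple tensor comparisons. A minimal sketch on toy reward tensors (illustrative only, not part of the generated file; the distributed `gather_for_metrics` step is omitted):

```python
import torch

# Per-pair rewards for three preference pairs
chosen_rewards = torch.tensor([0.5, -0.2, 1.0])
rejected_rewards = torch.tensor([0.1, 0.3, 0.4])

# Fraction of pairs where the chosen response is ranked above the rejected one
reward_accuracies = (chosen_rewards > rejected_rewards).float()
accuracy = reward_accuracies.mean().item()

# Average reward gap between chosen and rejected responses
margin = (chosen_rewards - rejected_rewards).mean().item()
```

Two of the three pairs are ranked correctly, so the accuracy is 2/3; the mean margin is (0.4 - 0.5 + 0.6) / 3.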
    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
        compute_loss_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with compute_loss_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # force log the metrics
        self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
        """Generate samples from the model and reference model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with generate_context_manager:
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.processing_class.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

        return policy_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ):
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with torch.no_grad(), prediction_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
        logits = torch.tensor(logits, device=self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
        `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded = self.generate_from_model(self.model, random_batch)

            table = pd.DataFrame(
                columns=["Prompt", "Policy"],
                data=[
                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
                ],
            )
            if "wandb" in self.args.report_to:
                wandb.log({"game_log": wandb.Table(data=table)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="game_log.csv",
                    table=table,
                )

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`dict[str, float]`):
                The values to log.
            start_time (`float`, *optional*):
                Start time of the training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs, start_time)

    def _shift_right(self, input_ids):
        if self.decoder_start_token_id is None:
            raise ValueError(
                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = self.decoder_start_token_id

        if self.pad_token_id is None:
            raise ValueError("model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids

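The non-FX branch of `_shift_right` can be demonstrated standalone. This is an illustrative sketch, not part of the generated file; the `decoder_start_token_id` and `pad_token_id` values are hypothetical.

```python
import torch

# Hypothetical token ids for the sketch
decoder_start_token_id = 0
pad_token_id = 0

# Labels with a -100 padding marker in the final position
input_ids = torch.tensor([[5, 6, 7, -100]])

# Shift right: prepend the decoder start token, drop the last token
shifted = input_ids.new_zeros(input_ids.shape)
shifted[..., 1:] = input_ids[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id

# Replace any surviving -100 markers with the pad token
shifted.masked_fill_(shifted == -100, pad_token_id)

print(shifted.tolist())  # [[0, 5, 6, 7]]
```

The -100 marker falls off the end during the shift, so nothing remains to be replaced here; the `masked_fill_` matters when padding sits mid-sequence.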
    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
class UnslothCPOTrainer(_UnslothCPOTrainer):
    """

    Initialize CPOTrainer.

    Args:
        model ([`~transformers.PreTrainedModel`]):
            The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
        args ([`CPOConfig`]):
            The CPO config arguments to use for training.
        data_collator ([`~transformers.DataCollator`]):
            The data collator to use for training. If None is specified, the default data collator
            ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset ([`~datasets.Dataset`]):
            The dataset to use for training.
        eval_dataset ([`~datasets.Dataset`]):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothCPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('cpo_trainer', other_metrics)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass
+
pass
|
| 1927 |
+
|
| 1928 |
+
|
| 1929 |
+
if hasattr(logger, "addFilter"):
|
| 1930 |
+
import logging
|
| 1931 |
+
class HideLoggingMessage(logging.Filter):
|
| 1932 |
+
def __init__(self, text): self.text = text
|
| 1933 |
+
def filter(self, x): return not (self.text in x.getMessage())
|
| 1934 |
+
pass
|
| 1935 |
+
logger.addFilter(HideLoggingMessage("`use_cache=True`"))
|
| 1936 |
+
|
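The module tail above silences the repeated `` `use_cache=True` `` warning by attaching a `logging.Filter` subclass to the trainer's logger. A standalone sketch of the same pattern (the `demo` logger name and sample messages are only for illustration):

```python
import logging

class HideLoggingMessage(logging.Filter):
    """Drop any log record whose rendered message contains `text`."""
    def __init__(self, text):
        super().__init__()
        self.text = text

    def filter(self, record):
        # Returning False tells the logging machinery to suppress the record.
        return self.text not in record.getMessage()

def make_record(msg):
    return logging.LogRecord("demo", logging.WARNING, __file__, 0, msg, None, None)

hide = HideLoggingMessage("`use_cache=True`")
suppressed = hide.filter(make_record("Setting `use_cache=True` is not supported"))  # False -> hidden
kept = hide.filter(make_record("Training started"))                                 # True  -> emitted
```

Attaching the filter with `logger.addFilter(...)`, as the generated code does, applies this check to every record the logger processes.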
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,1287 @@
"""
2026.3.2
2026.3.4
5.3.0
0.24.0
__UNSLOTH_VERSIONING__
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, nn, os, prepare_deepspeed, torch, warnings)


import os
import math
import logging
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
import inspect
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
from transformers.training_args import ParallelMode
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
# Wrap trainer with padding to right and enable training mode
# Also patches W&B since multiple runs must use wandb.finish()
import functools
from types import MethodType
try:
    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
except:
    def reset_unsloth_gradient_checkpointing_buffers(): pass
def prepare_for_training_mode(f):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # Enable training mode
        _was_training = None
        # Get gradient checkpointing setting from training arguments
        use_gc = getattr(self.args, 'gradient_checkpointing', True)
        if hasattr(self, 'model') and hasattr(self.model, "training"):
            _was_training = self.model.training
        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
            self.model.for_training(use_gradient_checkpointing=use_gc)
        output = f(self, *args, **kwargs)
        # Restore previous mode when possible
        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
            if _was_training is False:
                self.model.for_inference()
            elif _was_training is True and hasattr(self.model, "for_training"):
                self.model.for_training(use_gradient_checkpointing=use_gc)
        # Reset gradient checkpointing buffers to free memory while staying ready for next run
        try:
            reset_unsloth_gradient_checkpointing_buffers()
        except:
            pass
        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
        try:
            import wandb
            wandb.finish()
        except:
            pass
        return output
    return wrapper
pass
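The `prepare_for_training_mode` wrapper above flips the model into training mode for the duration of a trainer call and restores the previous mode afterwards. A minimal, dependency-free sketch of the same idea, using hypothetical `DummyModel`/`DummyTrainer` stand-ins for an Unsloth-patched model with `for_training`/`for_inference` hooks:

```python
import functools

def training_mode(f):
    """Force the model into training mode around a method call, then
    restore whatever mode it was in before the call."""
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        was_training = self.model.training      # remember the current mode
        self.model.for_training()               # force training mode
        try:
            return f(self, *args, **kwargs)
        finally:
            if not was_training:                # restore only if we changed it
                self.model.for_inference()
    return wrapper

class DummyModel:
    """Stands in for a model with for_training/for_inference hooks."""
    def __init__(self):
        self.training = False
    def for_training(self):
        self.training = True
    def for_inference(self):
        self.training = False

class DummyTrainer:
    def __init__(self):
        self.model = DummyModel()

    @training_mode
    def train(self):
        return self.model.training              # True while train() runs

trainer = DummyTrainer()
inside = trainer.train()        # training mode is active inside the call
after = trainer.model.training  # mode is restored once the call returns
```

The `try`/`finally` guarantees restoration even if training raises, which is why the generated wrapper also does its buffer resets and `wandb.finish()` after the wrapped call.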
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    # All Unsloth Zoo code licensed under AGPL3
    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    flat_index = index.reshape(-1)

    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)

    all_per_token_logps = []

    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()

        if logit_scale_multiply != 0.0:
            chunk_logits = chunk_logits * logit_scale_multiply
        if logit_scale_divide != 0.0:
            chunk_logits = chunk_logits / logit_scale_divide
        if logit_softcapping != 0.0:
            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)

        chunk_logits = chunk_logits.to(torch.float32)

        if temperature != 1.0:
            chunk_logits = chunk_logits / temperature

        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)

    all_per_token_logps = torch.concat(all_per_token_logps)

    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
    return all_per_token_logps

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    # Split into 4 chunks only
    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
    all_per_token_logps = []
    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
        chunk_logits = chunk_logits.to(torch.float32)
        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)
    pass
    all_per_token_logps = torch.concat(all_per_token_logps)
    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
    return all_per_token_logps
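Both helpers above compute the same quantity, log p(token) = logit[token] - logsumexp(logits), in chunks so that only a slice of the full-vocabulary logits is live at once. A dependency-free sketch of that math on plain lists (the chunking here only limits how many rows are processed per pass, mirroring the intent rather than the tensor implementation):

```python
import math

def selective_log_softmax(logits, index):
    """Per-row log-probability of the chosen entry:
    log_softmax(row)[idx] == row[idx] - logsumexp(row)."""
    out = []
    for row, idx in zip(logits, index):
        logsumexp = math.log(sum(math.exp(x) for x in row))
        out.append(row[idx] - logsumexp)
    return out

def chunked_selective_log_softmax(logits, index, chunks=4):
    """Same result, computed over up to `chunks` slices of rows so only
    one slice of logits needs to be materialized at a time."""
    n = len(logits)
    step = max(1, math.ceil(n / chunks))
    out = []
    for start in range(0, n, step):
        out.extend(selective_log_softmax(logits[start:start + step],
                                         index[start:start + step]))
    return out

logits = [[2.0, 0.5, -1.0], [0.0, 1.0, 3.0], [1.5, 1.5, 1.5]]
index = [0, 2, 1]
chunked = chunked_selective_log_softmax(logits, index)
reference = selective_log_softmax(logits, index)
```

Since each output element depends only on its own row, chunking over rows is exact: `chunked` and `reference` agree to floating-point precision, and every value is a valid log-probability (non-positive).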
def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a prompt tensor, returns the number of left-pad tokens in each sequence, so [pad, pad, pad, cat] = 3 tokens.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    prompt_section = input_ids[:, :-logits_to_keep]

    padding_mask = (prompt_section == pad_token_id)

    pad_token_counts = padding_mask.sum(dim=1)

    return pad_token_counts

def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a sequence [p,p,p,c,c,c,pad,pad,pad],

    where p are extra prompt tokens we got from slicing the torch tensor, c are completion tokens,
    and pad are pad tokens, this function builds a completion mask that zeroes out the pad
    and p tokens: in this example, [0,0,0,1,1,1,0,0,0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device=device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask

    return final_mask

def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    # Must use stable=True since a binary mask is otherwise unordered under argsort
    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor
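`left_pack_padding` above is a stable partition per row, done on tensors with `argsort(stable=True)` plus `gather`. The effect is easy to see with plain lists (a dependency-free sketch of the behavior, not the tensor implementation):

```python
def left_pack_padding(rows, pad_id):
    """Move all pad tokens in each row to the right end, keeping the relative
    order of the real tokens (the stable partition argsort(stable=True) gives)."""
    packed = []
    for row in rows:
        real = [tok for tok in row if tok != pad_id]       # keep non-pad tokens in order
        packed.append(real + [pad_id] * (len(row) - len(real)))
    return packed

rows = [[0, 0, 5, 7],
        [3, 0, 4, 0]]
packed = left_pack_padding(rows, pad_id=0)   # [[5, 7, 0, 0], [3, 4, 0, 0]]
```

Stability matters: without it, tokens with equal sort keys (all non-pads compare equal in the binary mask) could be reordered relative to each other, scrambling the sequence.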
def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Aligns a log probability tensor with a given attention mask.
    """

    device = logprob_tensor.device
    batch_size, logprob_seq_len = logprob_tensor.shape
    mask_seq_len = attention_mask.shape[1]

    padded_logprobs = torch.full(
        attention_mask.shape,
        fill_value=pad_value,
        dtype=logprob_tensor.dtype,
        device=device
    )

    left_pad_counts = torch.argmax(attention_mask, dim=1)

    cols = torch.arange(logprob_seq_len, device=device)
    dest_indices = left_pad_counts.unsqueeze(1) + cols

    # Create destination row indices
    # Shape: [batch_size, logprob_seq_len]
    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)

    # --- Filter out-of-bounds indices and perform assignment ---
    # Create a mask to identify only the indices that are within the bounds
    # of the target tensor's sequence length.
    valid_mask = dest_indices < mask_seq_len

    # Use this mask to select only the valid row indices, column indices,
    # and the corresponding values from the logprob tensor.
    # This flattens the selected elements into 1D tensors.
    valid_rows = row_indices[valid_mask]
    valid_cols = dest_indices[valid_mask]
    valid_vals = logprob_tensor[valid_mask]

    # Place the valid values into their correct positions in the padded tensor
    # using a single, efficient advanced indexing operation.
    padded_logprobs[valid_rows, valid_cols] = valid_vals

    return padded_logprobs
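The alignment above shifts each row of log-probs right by that row's left-padding count (the index of the first `1` in its attention mask, which is what `torch.argmax` returns for a 0/1 mask), dropping anything that runs off the end. The same logic on plain lists:

```python
def align_logprobs_with_mask(logprobs, attention_mask, pad_value=0.0):
    """Shift each row of log-probs right by its left-pad count and fill the
    remaining positions with pad_value; overflow values are discarded."""
    aligned = []
    for row, mask in zip(logprobs, attention_mask):
        left_pad = mask.index(1)             # index of the first attended token
        out = [pad_value] * len(mask)
        for j, value in enumerate(row):
            if left_pad + j < len(mask):     # out-of-bounds values are dropped
                out[left_pad + j] = value
        aligned.append(out)
    return aligned

logprobs = [[-0.1, -0.2, -0.3]]
attention_mask = [[0, 0, 1, 1, 1]]
aligned = align_logprobs_with_mask(logprobs, attention_mask)
# [[0.0, 0.0, -0.1, -0.2, -0.3]]
```

The tensor version does the same thing in one vectorized scatter instead of the explicit inner loop.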
def autotune_batch_and_chunks(
    total_input_rows,
    seq_len,
    hidden_size,
    vocab_size,
    dtype_bytes=16,
    multiplier=None
):
    if multiplier is None:
        final_m = max(4, seq_len // 4096)
    else:
        final_m = multiplier

    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()
        limit_gb = (free_bytes / (1024**3)) * 0.80
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # For XPU: estimate free memory from total - reserved
        total_mem = torch.xpu.get_device_properties(0).total_memory
        reserved_mem = torch.xpu.memory_reserved()
        free_bytes = total_mem - reserved_mem
        limit_gb = (free_bytes / (1024**3)) * 0.80
    else:
        # Fallback: assume 8GB available
        limit_gb = 8.0

    bytes_to_gb = 1024**3

    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)

    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb

    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
    logits_gb = base_logits / final_m

    total_mem_gb = hidden_gb + logits_gb

    valid_mask = total_mem_gb <= limit_gb
    valid_indices = torch.nonzero(valid_mask, as_tuple=False)

    if valid_indices.shape[0] == 0:
        # This means your GPU will OOM
        return 4, final_m

    best_idx = valid_indices[0].item()
    final_b = int(b_vals[best_idx].item())

    return final_b, final_m
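`autotune_batch_and_chunks` above scans batch sizes from `total_input_rows` downward and returns the first (largest) one whose estimated hidden-state plus logits memory fits 80% of free device memory, with a small fallback batch if nothing fits. A simplified, dependency-free sketch of that search, where the hypothetical `mem_per_row_gb` stands in for the per-row memory estimate:

```python
def autotune_batch(total_rows, mem_per_row_gb, limit_gb, fallback=4):
    """Scan batch sizes from total_rows down and return the first (largest)
    one whose estimated memory fits the budget; fall back if none fits."""
    for b in range(total_rows, 0, -1):
        if b * mem_per_row_gb <= limit_gb:
            return b
    return fallback

best = autotune_batch(64, 0.5, 10.0)     # 20 rows of 0.5 GB fit a 10 GB budget
too_big = autotune_batch(8, 5.0, 1.0)    # nothing fits -> fallback of 4
```

The generated code vectorizes this scan with `torch.arange` and a boolean mask, but the decision rule is the same descending search.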
| 313 |
+
def sanitize_logprob(logprob):
|
| 314 |
+
"""Local port of trl.scripts.vllm_serve.sanitize_logprob.
|
| 315 |
+
Filters NaN logprobs from vLLM outputs."""
|
| 316 |
+
value = logprob.logprob
|
| 317 |
+
if math.isnan(value):
|
| 318 |
+
logging.getLogger(__name__).warning(
|
| 319 |
+
f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
|
| 320 |
+
)
|
| 321 |
+
return None
|
| 322 |
+
return value
|
| 323 |
+
@dataclass
|
| 324 |
+
class UnslothGKDConfig(GKDConfig):
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
Configuration class for [`GKDTrainer`].
|
| 328 |
+
|
| 329 |
+
This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
|
| 330 |
+
please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 334 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 335 |
+
lmbda (`float`, *optional*, defaults to `0.5`):
|
| 336 |
+
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
| 337 |
+
student-generated outputs).
|
| 338 |
+
beta (`float`, *optional*, defaults to `0.5`):
|
| 339 |
+
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
| 340 |
+
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
| 341 |
+
max_new_tokens (`int`, *optional*, defaults to `128`):
|
| 342 |
+
Maximum number of tokens to generate per completion.
|
| 343 |
+
teacher_model_name_or_path (`str`, *optional*):
|
| 344 |
+
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
|
| 345 |
+
trained.
|
| 346 |
+
teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
|
| 347 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
| 348 |
+
from a string.
|
| 349 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 350 |
+
Whether to disable dropout in the model.
|
| 351 |
+
seq_kd (`bool`, *optional*, defaults to `False`):
|
| 352 |
+
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
|
| 353 |
+
teacher-generated output).
|
| 354 |
+
|
| 355 |
+
"""
|
| 356 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 357 |
+
default = None,
|
| 358 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 359 |
+
)
|
| 360 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 361 |
+
default = -1,
|
| 362 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 363 |
+
)
|
| 364 |
+
unsloth_logit_chunk_multiplier : Optional[int] = field(
|
| 365 |
+
default = None,
|
| 366 |
+
metadata = {'help': 'Multiplier for chunked logit computations.'},
|
| 367 |
+
)
|
| 368 |
+
unsloth_grpo_mini_batch : Optional[int] = field(
|
| 369 |
+
default = None,
|
| 370 |
+
metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
|
| 371 |
+
)
|
| 372 |
+
max_seq_length : Optional[int] = field(
|
| 373 |
+
default = None,
|
| 374 |
+
metadata = {'help': 'Maximum sequence length to truncate to.'},
|
| 375 |
+
)
|
| 376 |
+
def __init__(
|
| 377 |
+
self,
|
| 378 |
+
output_dir = None,
|
| 379 |
+
per_device_train_batch_size = 4,
|
| 380 |
+
num_train_epochs = 3.0,
|
| 381 |
+
max_steps = -1,
|
| 382 |
+
learning_rate = 5e-05,
|
| 383 |
+
lr_scheduler_type = 'linear',
|
| 384 |
+
lr_scheduler_kwargs = None,
|
| 385 |
+
warmup_steps = 0.1,
|
| 386 |
+
optim = 'adamw_8bit',
|
| 387 |
+
optim_args = None,
|
| 388 |
+
weight_decay = 0.01,
|
| 389 |
+
adam_beta1 = 0.9,
|
| 390 |
+
adam_beta2 = 0.999,
|
| 391 |
+
adam_epsilon = 1e-08,
|
| 392 |
+
optim_target_modules = None,
|
| 393 |
+
gradient_accumulation_steps = 2,
|
| 394 |
+
average_tokens_across_devices = True,
|
| 395 |
+
max_grad_norm = 1.0,
|
| 396 |
+
label_smoothing_factor = 0.0,
|
| 397 |
+
bf16 = False,
|
| 398 |
+
fp16 = False,
|
| 399 |
+
bf16_full_eval = False,
|
| 400 |
+
fp16_full_eval = False,
|
| 401 |
+
tf32 = None,
|
| 402 |
+
gradient_checkpointing = True,
|
| 403 |
+
gradient_checkpointing_kwargs = None,
|
| 404 |
+
torch_compile = False,
|
| 405 |
+
torch_compile_backend = None,
|
| 406 |
+
torch_compile_mode = None,
|
| 407 |
+
use_liger_kernel = False,
|
| 408 |
+
liger_kernel_config = None,
|
| 409 |
+
use_cache = False,
|
| 410 |
+
neftune_noise_alpha = None,
|
| 411 |
+
torch_empty_cache_steps = 250,
|
| 412 |
+
auto_find_batch_size = False,
|
| 413 |
+
logging_strategy = 'steps',
|
| 414 |
+
logging_steps = 1,
|
| 415 |
+
logging_first_step = False,
|
| 416 |
+
log_on_each_node = True,
|
| 417 |
+
logging_nan_inf_filter = False,
|
| 418 |
+
include_num_input_tokens_seen = False,
|
| 419 |
+
log_level = 'passive',
|
| 420 |
+
log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        model_init_kwargs = None,
        chat_template_path = None,
        dataset_text_field = 'text',
        dataset_kwargs = None,
        dataset_num_proc = None,
        eos_token = None,
        pad_token = None,
        max_length = 1024,
        packing = False,
        packing_strategy = 'bfd',
        padding_free = None,
        pad_to_multiple_of = None,
        eval_packing = None,
        completion_only_loss = None,
        assistant_only_loss = False,
        loss_type = 'nll',
        activation_offloading = False,
        temperature = 0.9,
        lmbda = 0.5,
        beta = 0.5,
        max_new_tokens = 128,
        teacher_model_name_or_path = None,
        teacher_model_init_kwargs = None,
        disable_dropout = True,
        seq_kd = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if num_train_epochs is None:
            num_train_epochs = 3.0  # Default to 3 epochs if None; max_steps will override
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        import multiprocessing as _mp
        if _mp.get_start_method() != 'fork':
            dataset_num_proc = None
        elif dataset_num_proc is None:
            import psutil
            dataset_num_proc = min(max((psutil.cpu_count() or 1) + 4, 2), 64)
            memory_gb_left = psutil.virtual_memory().available / (1024**3)
            if memory_gb_left <= 2: dataset_num_proc = 1
            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
        if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
            from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
            if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
                from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
                pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE

        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')

        super().__init__(
            output_dir = output_dir,
            per_device_train_batch_size = per_device_train_batch_size,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            learning_rate = learning_rate,
            lr_scheduler_type = lr_scheduler_type,
            lr_scheduler_kwargs = lr_scheduler_kwargs,
            warmup_steps = warmup_steps,
            optim = optim,
            optim_args = optim_args,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            optim_target_modules = optim_target_modules,
            gradient_accumulation_steps = gradient_accumulation_steps,
            average_tokens_across_devices = average_tokens_across_devices,
            max_grad_norm = max_grad_norm,
            label_smoothing_factor = label_smoothing_factor,
            bf16 = bf16,
            fp16 = fp16,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            use_cache = use_cache,
            neftune_noise_alpha = neftune_noise_alpha,
            torch_empty_cache_steps = torch_empty_cache_steps,
            auto_find_batch_size = auto_find_batch_size,
            logging_strategy = logging_strategy,
            logging_steps = logging_steps,
            logging_first_step = logging_first_step,
            log_on_each_node = log_on_each_node,
            logging_nan_inf_filter = logging_nan_inf_filter,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            log_level = log_level,
            log_level_replica = log_level_replica,
            disable_tqdm = disable_tqdm,
            report_to = report_to,
            run_name = run_name,
            project = project,
            trackio_space_id = trackio_space_id,
            eval_strategy = eval_strategy,
            eval_steps = eval_steps,
            eval_delay = eval_delay,
            per_device_eval_batch_size = per_device_eval_batch_size,
            prediction_loss_only = prediction_loss_only,
            eval_on_start = eval_on_start,
            eval_do_concat_batches = eval_do_concat_batches,
            eval_use_gather_object = eval_use_gather_object,
            eval_accumulation_steps = eval_accumulation_steps,
            batch_eval_metrics = batch_eval_metrics,
            save_only_model = save_only_model,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_on_each_node = save_on_each_node,
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            model_init_kwargs = model_init_kwargs,
            chat_template_path = chat_template_path,
            dataset_text_field = dataset_text_field,
            dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            eos_token = eos_token,
            pad_token = pad_token,
            max_length = max_length,
            packing = packing,
            packing_strategy = packing_strategy,
            padding_free = padding_free,
            pad_to_multiple_of = pad_to_multiple_of,
            eval_packing = eval_packing,
            completion_only_loss = completion_only_loss,
            assistant_only_loss = assistant_only_loss,
            loss_type = loss_type,
            activation_offloading = activation_offloading,
            temperature = temperature,
            lmbda = lmbda,
            beta = beta,
            max_new_tokens = max_new_tokens,
            teacher_model_name_or_path = teacher_model_name_or_path,
            teacher_model_init_kwargs = teacher_model_init_kwargs,
            disable_dropout = disable_dropout,
            seq_kd = seq_kd, **kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length

pass

class _UnslothGKDTrainer(SFTTrainer):
    """"""

    _tag_names = ["trl", "gkd"]
    _name = "GKD"
    _paper = {
        "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
        "id": "2306.13649",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @inproceedings{agarwal2024on-policy,
                title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
                author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
                year = 2024,
                booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
                publisher = {OpenReview.net},
                url = {https://openreview.net/forum?id=3zKtaqxLhW},
            }"""),
    }

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
        args: Optional[GKDConfig] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Callable] = None,
    ):
        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
            warnings.warn(
                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
                "it and want it to remain, please share your comments here: "
                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
                "TRL_EXPERIMENTAL_SILENCE=1."
            )
        # Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts")
        args.remove_unused_columns = False
        # Respect a user-provided data_collator; otherwise, provide a ChatML collator that pads to `args.max_length`.
        if data_collator is None:
            data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

        # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator,
        # so that raw conversational fields (e.g., "messages") remain available to the collator.
        if args.dataset_kwargs is None:
            args.dataset_kwargs = {"skip_prepare_dataset": True}
        else:
            args.dataset_kwargs["skip_prepare_dataset"] = True

        # Liger fused GKD loss (JSD)
        self.use_liger_gkd_loss = False
        if args.use_liger_kernel:
            self.liger_jsd_loss = LigerFusedLinearJSDLoss(
                beta=args.beta,
                ignore_index=-100,
                temperature=args.temperature,
                compiled=False,
            )
            self.use_liger_gkd_loss = True

        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}
        elif not isinstance(teacher_model, str):
            raise ValueError(
                "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
            )
        else:
            teacher_model_init_kwargs = args.teacher_model_init_kwargs
            teacher_model_init_kwargs["dtype"] = (
                teacher_model_init_kwargs["dtype"]
                if teacher_model_init_kwargs["dtype"] in ["auto", None]
                else getattr(torch, teacher_model_init_kwargs["dtype"])
            )

        if isinstance(teacher_model, str):
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        if self.is_deepspeed_enabled:
            self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
        else:
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
        self.seq_kd = args.seq_kd

        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=True,
            top_k=0,
            use_cache=False if args.gradient_checkpointing else True,
            pad_token_id=self.processing_class.pad_token_id,
        )
        # Set custom EOS tokens if they are specified by the model's generation
        # config. This is important for models with the Llama 3 chat template,
        # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
        # turns or messages.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
        """
        Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
        of https://huggingface.co/papers/2306.13649 for the definition.

        Args:
            student_logits:
                Tensor of shape (batch_size, sequence_length, vocab_size)
            teacher_logits:
                Tensor of shape (batch_size, sequence_length, vocab_size)
            labels:
                Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
                loss
            beta:
                Interpolation coefficient between 0 and 1 (default: 0.5)
            temperature:
                Softmax temperature (default: 1.0)
            reduction:
                Specifies the reduction to apply to the output (default: 'batchmean')

        Returns:
            loss: Scalar tensor with the generalized JSD loss
        """

        # Apply temperature scaling
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # Compute log probabilities for student and probabilities for teacher
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        if beta == 0:
            jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
            jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # Compute the log of the mixture distribution
            # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
            beta = torch.tensor(beta, dtype=student_log_probs.dtype)
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
                dim=0,
            )

            # Compute KL divergences using F.kl_div
            # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            # Compute the Generalized Jensen-Shannon Divergence
            jsd = beta * kl_teacher + (1 - beta) * kl_student

        # Masking
        if labels is not None:
            mask = labels != -100
            jsd = jsd[mask]

        # Apply reduction
        if reduction == "batchmean":
            return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
        elif reduction == "sum":
            return jsd.sum()
        elif reduction == "mean":
            return jsd.mean()
        else:
            return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if self.use_liger_gkd_loss:
            # Forward only through the base models (avoid lm_head to save memory)
            unwrapped_student = self.accelerator.unwrap_model(model)
            if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
                base_student = unwrapped_student.get_decoder()
            else:
                base_student = getattr(
                    unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
                )

            student_outputs = base_student(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                output_hidden_states=True,
                use_cache=False,
            )

            self.teacher_model.eval()
            unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
            if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None:
                base_teacher = unwrapped_teacher.get_decoder()
            else:
                base_teacher = getattr(
                    unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher
                )
            with torch.no_grad():
                teacher_outputs = base_teacher(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    output_hidden_states=True,
                    use_cache=False,
                )

            # hidden states (shifted)
            student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
            teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()

            # labels mask and labels (shifted)
            labels_mask = inputs["labels"] != -100
            masked_input_ids = torch.where(
                labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)
            )
            true_labels = masked_input_ids[:, 1:].contiguous()

            # heads
            student_head = unwrapped_student.get_output_embeddings()
            teacher_head = unwrapped_teacher.get_output_embeddings()

            # liger fused jsd loss
            loss = self.liger_jsd_loss(
                student_input=student_hidden,
                student_weight=student_head.weight,
                teacher_input=teacher_hidden,
                teacher_weight=teacher_head.weight,
                true_labels=true_labels,
                student_bias=getattr(student_head, "bias", None),
                teacher_bias=getattr(teacher_head, "bias", None),
            )
        else:
            # compute student output
            student_outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

            # compute teacher output in eval mode
            self.teacher_model.eval()
            with torch.no_grad():
                teacher_outputs = self.teacher_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )

            # slice the logits for the generated tokens using the inputs["prompts"] lengths
            prompt_lengths = inputs["prompts"].shape[1]
            shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :]
            shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :]
            shifted_labels = inputs["labels"][:, prompt_lengths:]

            # compute loss
            loss = self.generalized_jsd_loss(
                student_logits=shifted_student_logits,
                teacher_logits=shifted_teacher_logits,
                labels=shifted_labels,
                beta=self.beta,
            )

        # empty cache
        empty_cache()

        # Return loss
        return (loss, student_outputs) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        # Generate output with respect to the prompt only
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        # Get the generated token IDs
        generated_tokens = generated_outputs.sequences
        # Calculate new attention mask
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        # If there's pad_token_id, set attention mask to 0 for padding tokens
        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD paper. With probability
        `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
        the original inputs.
        """
        if self.seq_kd:
            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels
        if random.random() <= self.lmbda:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

class UnslothGKDTrainer(_UnslothGKDTrainer):
    """
    Trainer for Generalized Knowledge Distillation (GKD) of language models.

    For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated
    Mistakes](https://huggingface.co/papers/2306.13649).

    Args:
        model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
            Model to be trained, or the string identifier of the model to be instantiated from a pretrained model.
        teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
            Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a
            pretrained model.
        args ([`GKDConfig`], *optional*):
            Training arguments.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Data collator to batch samples from the dataset. It defaults to a [`DataCollatorForChatML`] using the
            `processing_class`.
        train_dataset ([`~datasets.Dataset`], *optional*):
            Dataset for training.
        eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
            Dataset for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Class to process the data.
        compute_metrics (`Callable`, *optional*):
            Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
            dictionary mapping strings to floats.
        callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
            Callbacks to use during training.
        optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
            Tuple containing the optimizer and the learning rate scheduler to use for training.
        preprocess_logits_for_metrics (`Callable`, *optional*):
            Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
            return the logits to be used for metrics computation.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
            wrapped with the specified PEFT adapter.
        formatting_func (`Callable`, *optional*):
            Function to format the dataset. Must take in an example and return an example.

    """
    def __init__(
        self,
        model = None,
        teacher_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        if args is None: args = UnslothGKDConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1151 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1152 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1153 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1154 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1155 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1156 |
+
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1157 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1158 |
+
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1159 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1160 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1161 |
+
if force_float32:
|
| 1162 |
+
args.bf16_full_eval = False
|
| 1163 |
+
args.fp16_full_eval = False
|
| 1164 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1165 |
+
args.bf16_full_eval = True
|
| 1166 |
+
args.fp16_full_eval = False
|
| 1167 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1168 |
+
args.bf16_full_eval = args.bf16
|
| 1169 |
+
args.fp16_full_eval = args.fp16
|
| 1170 |
+
_output_logits = False
|
| 1171 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1172 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1173 |
+
if _output_logits:
|
| 1174 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1175 |
+
if model is not None:
|
| 1176 |
+
_warnings_issued = getattr(model, 'warnings_issued', None)
|
| 1177 |
+
if _warnings_issued is None:
|
| 1178 |
+
model.warnings_issued = {}
|
| 1179 |
+
elif not isinstance(_warnings_issued, dict):
|
| 1180 |
+
try:
|
| 1181 |
+
model.warnings_issued = dict(_warnings_issued)
|
| 1182 |
+
except Exception:
|
| 1183 |
+
model.warnings_issued = {}
|
| 1184 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1185 |
+
pass
|
| 1186 |
+
else:
|
| 1187 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1188 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1189 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1190 |
+
max_seq_length = model.max_seq_length
|
| 1191 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1192 |
+
elif args_max_seq_length is not None and model_max_seq_length is not None:
|
| 1193 |
+
if args_max_seq_length > model_max_seq_length:
|
| 1194 |
+
print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
|
| 1195 |
+
'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
|
| 1196 |
+
args.max_seq_length = model_max_seq_length
|
| 1197 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1198 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1199 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1200 |
+
if 'processing_class' in locals():
|
| 1201 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1202 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1203 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1204 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1205 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1206 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1207 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1208 |
+
__tokenizer,
|
| 1209 |
+
mlm = False,
|
| 1210 |
+
mlm_probability = 0.0,
|
| 1211 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1212 |
+
)
|
| 1213 |
+
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1214 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1215 |
+
__tokenizer,
|
| 1216 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1217 |
+
)
|
| 1218 |
+
else:
|
| 1219 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1220 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1221 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1222 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1223 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1224 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1225 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1226 |
+
__tokenizer.tokenizer,
|
| 1227 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1228 |
+
)
|
| 1229 |
+
else:
|
| 1230 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1231 |
+
__tokenizer.tokenizer,
|
| 1232 |
+
mlm = False,
|
| 1233 |
+
mlm_probability = 0.0,
|
| 1234 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1235 |
+
)
|
| 1236 |
+
other_metrics = []
|
| 1237 |
+
|
| 1238 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1239 |
+
PatchRLStatistics('gkd_trainer', other_metrics)
|
| 1240 |
+
|
| 1241 |
+
# [TODO] Fix up DataParallel multiplying batch sizes
|
| 1242 |
+
# [TODO] DDP works, but DP seems to not work? [TODO]
|
| 1243 |
+
if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
|
| 1244 |
+
if getattr(args, "_n_gpu", 1) != 1:
|
| 1245 |
+
args._n_gpu = 1
|
| 1246 |
+
if "model" in locals() and hasattr(model, "for_training"):
|
| 1247 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1248 |
+
super().__init__(
|
| 1249 |
+
model = model,
|
| 1250 |
+
teacher_model = teacher_model,
|
| 1251 |
+
args = args,
|
| 1252 |
+
data_collator = data_collator,
|
| 1253 |
+
train_dataset = train_dataset,
|
| 1254 |
+
eval_dataset = eval_dataset,
|
| 1255 |
+
processing_class = processing_class,
|
| 1256 |
+
compute_metrics = compute_metrics,
|
| 1257 |
+
callbacks = callbacks,
|
| 1258 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1259 |
+
peft_config = peft_config,
|
| 1260 |
+
formatting_func = formatting_func,**kwargs)
|
| 1261 |
+
if "model" in locals() and hasattr(model, "for_inference"):
|
| 1262 |
+
model.for_inference()
|
| 1263 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1264 |
+
self.neftune_hook_handle.remove()
|
| 1265 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1266 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1267 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1268 |
+
pass
|
| 1269 |
+
if hasattr(self, 'accelerator'):
|
| 1270 |
+
scaler = self.accelerator.scaler
|
| 1271 |
+
current_model = model
|
| 1272 |
+
while hasattr(current_model, 'model'):
|
| 1273 |
+
current_model.accelerator_scaler = scaler
|
| 1274 |
+
current_model = current_model.model
|
| 1275 |
+
current_model.accelerator_scaler = scaler
|
| 1276 |
+
pass
|
| 1277 |
+
if hasattr(self, 'train'):
|
| 1278 |
+
self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
|
| 1279 |
+
pass
|
| 1280 |
+
if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
|
| 1281 |
+
_vllm_tok = self.llm.get_tokenizer()
|
| 1282 |
+
_pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
|
| 1283 |
+
if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
|
| 1284 |
+
_vllm_tok.chat_template = _pc.chat_template
|
| 1285 |
+
pass
|
| 1286 |
+
|
| 1287 |
+
pass
|
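The `__init__` above resolves the `fp16`/`bf16` training flags from the model's dtype and two environment switches (`UNSLOTH_FORCE_FLOAT32`, `UNSLOTH_MIXED_PRECISION`). The following is a minimal pure-Python sketch of that decision table, using a hypothetical `resolve_precision` helper that is not part of the Unsloth API; it only mirrors the branching logic for illustration:

```python
def resolve_precision(model_dtype, force_float32=False,
                      mixed_precision_env="float32",
                      use_fp16=False, use_bf16=False):
    """Simplified mirror of the flag resolution in the trainer __init__ above."""
    float16 = model_dtype == "float16"
    # Requesting the opposite half precision from the model's dtype is an error
    if not force_float32 and float16 and use_bf16:
        raise TypeError("Model is float16 but bf16 was requested")
    if not force_float32 and not float16 and use_fp16:
        raise TypeError("Model is bfloat16 but fp16 was requested")
    if force_float32:
        # Forced float32: no autocast at all
        return {"fp16": False, "bf16": False, "accelerate": "no"}
    if (not use_bf16 and not use_fp16) and mixed_precision_env == "float32":
        # Default: autocast in the model's own half precision
        return {"fp16": float16, "bf16": not float16,
                "accelerate": "fp16" if float16 else "bf16"}
    if mixed_precision_env == "bfloat16":
        # Full bf16 finetuning does no autocasting, so both flags stay False
        return {"fp16": False, "bf16": False, "accelerate": "no"}
    return {"fp16": use_fp16, "bf16": use_bf16,
            "accelerate": "fp16" if use_fp16 else ("bf16" if use_bf16 else "no")}
```

For example, a float16 model with default arguments resolves to `fp16=True`, while `UNSLOTH_MIXED_PRECISION=bfloat16` disables autocast entirely.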
unsloth_compiled_cache/UnslothKTOTrainer.py
ADDED
The diff for this file is too large to render. See raw diff
unsloth_compiled_cache/UnslothORPOTrainer.py
ADDED
@@ -0,0 +1,1860 @@
"""
2026.3.2
2026.3.4
5.3.0
0.24.0
__UNSLOTH_VERSIONING__
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)


import os
import math
import logging
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
import inspect
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
from transformers.training_args import ParallelMode
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize

# Wrap trainer with padding to right and enable training mode
# Also patches W&B since multiple runs must use wandb.finish()
import functools
from types import MethodType
try:
    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
except:
    def reset_unsloth_gradient_checkpointing_buffers(): pass
def prepare_for_training_mode(f):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # Enable training mode
        _was_training = None
        # Get gradient checkpointing setting from training arguments
        use_gc = getattr(self.args, 'gradient_checkpointing', True)
        if hasattr(self, 'model') and hasattr(self.model, "training"):
            _was_training = self.model.training
        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
            self.model.for_training(use_gradient_checkpointing=use_gc)
        output = f(self, *args, **kwargs)
        # Restore previous mode when possible
        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
            if _was_training is False:
                self.model.for_inference()
            elif _was_training is True and hasattr(self.model, "for_training"):
                self.model.for_training(use_gradient_checkpointing=use_gc)
        # Reset gradient checkpointing buffers to free memory while staying ready for next run
        try:
            reset_unsloth_gradient_checkpointing_buffers()
        except:
            pass
        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
        try:
            import wandb
            wandb.finish()
        except:
            pass
        return output
    return wrapper
pass

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    # All Unsloth Zoo code licensed under AGPL3
    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    flat_index = index.reshape(-1)

    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)

    all_per_token_logps = []

    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()

        if logit_scale_multiply != 0.0:
            chunk_logits = chunk_logits * logit_scale_multiply
        if logit_scale_divide != 0.0:
            chunk_logits = chunk_logits / logit_scale_divide
        if logit_softcapping != 0.0:
            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)

        chunk_logits = chunk_logits.to(torch.float32)

        if temperature != 1.0:
            chunk_logits = chunk_logits / temperature

        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)

    all_per_token_logps = torch.concat(all_per_token_logps)

    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
    return all_per_token_logps

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    # Split into 4 chunks only
    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
    all_per_token_logps = []
    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
        chunk_logits = chunk_logits.to(torch.float32)
        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)
    pass
    all_per_token_logps = torch.concat(all_per_token_logps)
    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
    return all_per_token_logps
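Both chunked kernels above compute the same per-token quantity: the log-probability of the selected token, `logits[index] - logsumexp(logits)`. A tiny pure-Python check of that identity (illustrative only; the real code operates on torch tensors and splits the work into chunks to bound peak memory):

```python
import math

def selective_log_softmax(logits, index):
    """Log-probability of the selected index: logits[index] - logsumexp(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    logsumexp = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[index] - logsumexp

# The selected log-prob equals log(softmax(logits)[index])
logits = [1.0, 2.0, 3.0]
probs = [math.exp(x) for x in logits]
probs = [p / sum(probs) for p in probs]
assert abs(selective_log_softmax(logits, 2) - math.log(probs[2])) < 1e-12
```

Gathering the selected logit and subtracting the logsumexp avoids materializing the full softmax, which is why the chunked variants scale to large vocabularies.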
| 161 |
+
def calculate_pad_tokens_in_prompt(
|
| 162 |
+
input_ids: torch.Tensor,
|
| 163 |
+
logits_to_keep: int,
|
| 164 |
+
pad_token_id: int
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
|
| 168 |
+
"""
|
| 169 |
+
if logits_to_keep >= input_ids.shape[1]:
|
| 170 |
+
raise ValueError("logits_to_keep must be smaller than the sequence length.")
|
| 171 |
+
|
| 172 |
+
prompt_section = input_ids[:, :-logits_to_keep]
|
| 173 |
+
|
| 174 |
+
padding_mask = (prompt_section == pad_token_id)
|
| 175 |
+
|
| 176 |
+
pad_token_counts = padding_mask.sum(dim=1)
|
| 177 |
+
|
| 178 |
+
return pad_token_counts
|
| 179 |
+
|
| 180 |
+
def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra
    prompt tokens left over from slicing the tensor, c are completion tokens,
    and pad are pad tokens, this function builds a completion mask that zeroes
    out the pad and p tokens: in this example, [0, 0, 0, 1, 1, 1, 0, 0, 0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device=device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask

    return final_mask

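The masking rule above (keep positions at or beyond `max_left_pad - left_pad_count`, and drop pad tokens) can be illustrated in pure Python for one row (`PAD` and the helper name are hypothetical):

```python
PAD = 0  # hypothetical pad token id

def completion_mask(completion_ids, left_pad_count, max_left_pad):
    # Positions before (max_left_pad - left_pad_count) are leftover prompt
    # tokens from slicing; those and any pad tokens are zeroed out.
    num_to_mask = max_left_pad - left_pad_count
    return [int(i >= num_to_mask and tok != PAD)
            for i, tok in enumerate(completion_ids)]

# [p, c, c, pad] with one leftover prompt token:
mask = completion_mask([5, 6, 7, PAD], left_pad_count=1, max_left_pad=2)  # → [0, 1, 1, 0]
```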
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    # Must use stable=True since the binary mask is otherwise unordered
    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor

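The stable argsort above is a stable partition: non-pad tokens keep their relative order while pads move right. A pure-Python sketch for one sequence (`PAD` and the helper name are hypothetical):

```python
PAD = 0  # hypothetical pad token id

def left_pack(seq):
    # Stable partition: non-pad tokens first, in their original order,
    # then the pads -- the effect argsort(..., stable=True) achieves per row.
    return [t for t in seq if t != PAD] + [t for t in seq if t == PAD]

packed = left_pack([PAD, 5, PAD, 6, 7])  # → [5, 6, 7, 0, 0]
```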
def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Aligns a log probability tensor with a given attention mask.
    """

    device = logprob_tensor.device
    batch_size, logprob_seq_len = logprob_tensor.shape
    mask_seq_len = attention_mask.shape[1]

    padded_logprobs = torch.full(
        attention_mask.shape,
        fill_value=pad_value,
        dtype=logprob_tensor.dtype,
        device=device
    )

    left_pad_counts = torch.argmax(attention_mask, dim=1)

    cols = torch.arange(logprob_seq_len, device=device)
    dest_indices = left_pad_counts.unsqueeze(1) + cols

    # Create destination row indices
    # Shape: [batch_size, logprob_seq_len]
    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)

    # Filter out-of-bounds indices and perform the assignment.
    # Create a mask to identify only the indices that are within the bounds
    # of the target tensor's sequence length.
    valid_mask = dest_indices < mask_seq_len

    # Use this mask to select only the valid row indices, column indices,
    # and the corresponding values from the logprob tensor.
    # This flattens the selected elements into 1D tensors.
    valid_rows = row_indices[valid_mask]
    valid_cols = dest_indices[valid_mask]
    valid_vals = logprob_tensor[valid_mask]

    # Place the valid values into their correct positions in the padded tensor
    # using a single, efficient advanced indexing operation.
    padded_logprobs[valid_rows, valid_cols] = valid_vals

    return padded_logprobs

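The scatter above shifts each row of log-probs right by the number of leading zeros in its attention mask (which `argmax` over a 0/1 mask yields), dropping out-of-bounds destinations. A single-row pure-Python sketch (helper name hypothetical):

```python
def align_row(logprobs, mask_row, pad_value=0.0):
    # index of the first 1 = number of leading (left-pad) zeros;
    # argmax on an all-zero mask returns 0, hence the fallback.
    left_pad = mask_row.index(1) if 1 in mask_row else 0
    out = [pad_value] * len(mask_row)
    for j, v in enumerate(logprobs):
        dest = left_pad + j
        if dest < len(out):  # drop out-of-bounds destinations
            out[dest] = v
    return out

aligned = align_row([-1.0, -2.0], [0, 0, 1, 1])  # → [0.0, 0.0, -1.0, -2.0]
```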
def autotune_batch_and_chunks(
    total_input_rows,
    seq_len,
    hidden_size,
    vocab_size,
    dtype_bytes=16,
    multiplier=None
):
    if multiplier is None:
        final_m = max(4, seq_len // 4096)
    else:
        final_m = multiplier

    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()
        limit_gb = (free_bytes / (1024**3)) * 0.80
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # For XPU: estimate free memory from total - reserved
        total_mem = torch.xpu.get_device_properties(0).total_memory
        reserved_mem = torch.xpu.memory_reserved()
        free_bytes = total_mem - reserved_mem
        limit_gb = (free_bytes / (1024**3)) * 0.80
    else:
        # Fallback: assume 8GB available
        limit_gb = 8.0

    bytes_to_gb = 1024**3

    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)

    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb

    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
    logits_gb = base_logits / final_m

    total_mem_gb = hidden_gb + logits_gb

    valid_mask = total_mem_gb <= limit_gb
    valid_indices = torch.nonzero(valid_mask, as_tuple=False)

    if valid_indices.shape[0] == 0:
        # No batch size fits within the limit; the GPU will likely OOM
        return 4, final_m

    best_idx = valid_indices[0].item()
    final_b = int(b_vals[best_idx].item())

    return final_b, final_m

def sanitize_logprob(logprob):
    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
    Filters NaN logprobs from vLLM outputs."""
    value = logprob.logprob
    if math.isnan(value):
        logging.getLogger(__name__).warning(
            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
        )
        return None
    return value

@dataclass
class UnslothORPOConfig(ORPOConfig):
    """

    Configuration class for the [`ORPOTrainer`].

    This class includes only the parameters that are specific to ORPO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int`, *optional*):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the relative ratio loss weight in the ORPO loss. In the
            [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the
            [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int`, *optional*):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
        is_encoder_decoder (`bool`, *optional*):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    unsloth_logit_chunk_multiplier : Optional[int] = field(
        default = None,
        metadata = {'help': 'Multiplier for chunked logit computations.'},
    )
    unsloth_grpo_mini_batch : Optional[int] = field(
        default = None,
        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless the user defines it.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        per_device_train_batch_size = 4,
        num_train_epochs = 3.0,
        max_steps = -1,
        learning_rate = 5e-05,
        lr_scheduler_type = 'linear',
        lr_scheduler_kwargs = None,
        warmup_steps = 0.1,
        optim = 'adamw_8bit',
        optim_args = None,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        optim_target_modules = None,
        gradient_accumulation_steps = 2,
        average_tokens_across_devices = True,
        max_grad_norm = 1.0,
        label_smoothing_factor = 0.0,
        bf16 = False,
        fp16 = False,
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = None,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        use_liger_kernel = False,
        liger_kernel_config = None,
        use_cache = False,
        neftune_noise_alpha = None,
        torch_empty_cache_steps = 250,
        auto_find_batch_size = False,
        logging_strategy = 'steps',
        logging_steps = 1,
        logging_first_step = False,
        log_on_each_node = True,
        logging_nan_inf_filter = False,
        include_num_input_tokens_seen = False,
        log_level = 'passive',
        log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        beta = 0.1,
        disable_dropout = True,
        label_pad_token_id = -100,
        padding_value = None,
        truncation_mode = 'keep_end',
        generate_during_eval = False,
        is_encoder_decoder = None,
        model_init_kwargs = None,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
|
| 520 |
+
if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 521 |
+
if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 522 |
+
if num_train_epochs is None:
|
| 523 |
+
num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
|
| 524 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 525 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 526 |
+
save_strategy = 'no'
|
| 527 |
+
import multiprocessing as _mp
|
| 528 |
+
if _mp.get_start_method() != 'fork':
|
| 529 |
+
dataset_num_proc = None
|
| 530 |
+
elif dataset_num_proc is None:
|
| 531 |
+
import psutil
|
| 532 |
+
dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
|
| 533 |
+
memory_gb_left = psutil.virtual_memory().available / (1024**3)
|
| 534 |
+
if memory_gb_left <= 2: dataset_num_proc = 1
|
| 535 |
+
else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
|
| 536 |
+
|
| 537 |
+
        super().__init__(
            output_dir = output_dir,
            per_device_train_batch_size = per_device_train_batch_size,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            learning_rate = learning_rate,
            lr_scheduler_type = lr_scheduler_type,
            lr_scheduler_kwargs = lr_scheduler_kwargs,
            warmup_steps = warmup_steps,
            optim = optim,
            optim_args = optim_args,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            optim_target_modules = optim_target_modules,
            gradient_accumulation_steps = gradient_accumulation_steps,
            average_tokens_across_devices = average_tokens_across_devices,
            max_grad_norm = max_grad_norm,
            label_smoothing_factor = label_smoothing_factor,
            bf16 = bf16,
            fp16 = fp16,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            use_cache = use_cache,
            neftune_noise_alpha = neftune_noise_alpha,
            torch_empty_cache_steps = torch_empty_cache_steps,
            auto_find_batch_size = auto_find_batch_size,
            logging_strategy = logging_strategy,
            logging_steps = logging_steps,
            logging_first_step = logging_first_step,
            log_on_each_node = log_on_each_node,
            logging_nan_inf_filter = logging_nan_inf_filter,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            log_level = log_level,
            log_level_replica = log_level_replica,
            disable_tqdm = disable_tqdm,
            report_to = report_to,
            run_name = run_name,
            project = project,
            trackio_space_id = trackio_space_id,
            eval_strategy = eval_strategy,
            eval_steps = eval_steps,
            eval_delay = eval_delay,
            per_device_eval_batch_size = per_device_eval_batch_size,
            prediction_loss_only = prediction_loss_only,
            eval_on_start = eval_on_start,
            eval_do_concat_batches = eval_do_concat_batches,
            eval_use_gather_object = eval_use_gather_object,
            eval_accumulation_steps = eval_accumulation_steps,
            batch_eval_metrics = batch_eval_metrics,
            save_only_model = save_only_model,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_on_each_node = save_on_each_node,
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            beta = beta,
            disable_dropout = disable_dropout,
            label_pad_token_id = label_pad_token_id,
            padding_value = padding_value,
            truncation_mode = truncation_mode,
            generate_during_eval = generate_during_eval,
            is_encoder_decoder = is_encoder_decoder,
            model_init_kwargs = model_init_kwargs,
            dataset_num_proc = dataset_num_proc, **kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    "which is `per_device_train_batch_size * gradient_accumulation_steps`."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length

pass

class _UnslothORPOTrainer(BaseTrainer):
    r""""""

    _tag_names = ["trl", "orpo"]
    _name = "ORPO"
    _paper = {
        "title": "ORPO: Monolithic Preference Optimization without Reference Model",
        "id": "2403.07691",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @article{hong2024orpo,
                title  = {{ORPO: Monolithic Preference Optimization without Reference Model}},
                author = {Jiwoo Hong and Noah Lee and James Thorne},
                year   = 2024,
                eprint = {arXiv:2403.07691}
            }"""),
    }

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[ORPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    ):
        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
            warnings.warn(
                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
                "it and want it to remain, please share your comments here: "
                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
                "TRL_EXPERIMENTAL_SILENCE=1."
            )
        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            dtype = model_init_kwargs.get("dtype")
            if dtype is not None:
                # Convert to `torch.dtype` if a str is passed
                if isinstance(dtype, str) and dtype != "auto":
                    dtype = getattr(torch, dtype)
                if dtype != "auto" and not isinstance(dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
                    )
                model_init_kwargs["dtype"] = dtype

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif args.gradient_checkpointing:
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = model
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif args.gradient_checkpointing:
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        if self.is_encoder_decoder:
            self.decoder_start_token_id = model.config.decoder_start_token_id
            self.pad_token_id = model.config.pad_token_id

        if processing_class is None:
            raise ValueError("processing_class must be specified to tokenize an ORPO dataset.")
        if args.max_length is None:
            logger.warning(
                "`max_length` is not set in the ORPOConfig's init;"
                " it will default to `512`, but you should set it yourself in the future.",
            )
            max_length = 512
        else:
            max_length = args.max_length
        if args.max_prompt_length is None:
            logger.warning(
                "`max_prompt_length` is not set in the ORPOConfig's init;"
                " it will default to `128`, but you should set it yourself in the future.",
            )
            max_prompt_length = 128
        else:
            max_prompt_length = args.max_prompt_length

        if args.max_completion_length is None and self.is_encoder_decoder:
            logger.warning(
                "When using an encoder-decoder architecture, you should set `max_completion_length` in the ORPOConfig's init;"
                " it will default to `128`, but you should set it yourself in the future.",
            )
            self.max_completion_length = 128
        else:
            self.max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                logger.warning(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
                    " we have set it for you, but you should do it yourself in the future.",
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.processing_class = processing_class

        self.beta = args.beta
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            logger.warning(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
            )

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().main_process_first():
            # Extract the prompt if needed, and apply the chat template if needed
            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            train_dataset = train_dataset.map(
                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
            )
            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 900 |
+
if eval_dataset is not None:
|
| 901 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 902 |
+
eval_dataset = eval_dataset.map(
|
| 903 |
+
maybe_apply_chat_template,
|
| 904 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 905 |
+
num_proc=args.dataset_num_proc,
|
| 906 |
+
)
|
| 907 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 908 |
+
|
| 909 |
+
super().__init__(
|
| 910 |
+
model=model,
|
| 911 |
+
args=args,
|
| 912 |
+
data_collator=data_collator,
|
| 913 |
+
train_dataset=train_dataset,
|
| 914 |
+
eval_dataset=eval_dataset,
|
| 915 |
+
processing_class=processing_class,
|
| 916 |
+
model_init=model_init,
|
| 917 |
+
compute_metrics=compute_metrics,
|
| 918 |
+
callbacks=callbacks,
|
| 919 |
+
optimizers=optimizers,
|
| 920 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 924 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 925 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 926 |
+
self.model_accepts_loss_kwargs = False
|
| 927 |
+
|
| 928 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 929 |
+
if hasattr(self.model, "add_model_tags"):
|
| 930 |
+
self.model.add_model_tags(self._tag_names)
|
| 931 |
+
|
| 932 |
+
if not hasattr(self, "accelerator"):
|
| 933 |
+
raise AttributeError(
|
| 934 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
def build_tokenized_answer(self, prompt, answer):
|
| 938 |
+
"""
|
| 939 |
+
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
|
| 940 |
+
b)[len(enc(a)):]`. Reference:
|
| 941 |
+
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 942 |
+
"""
|
| 943 |
+
|
| 944 |
+
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
| 945 |
+
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
| 946 |
+
|
| 947 |
+
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
| 948 |
+
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
| 949 |
+
|
| 950 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
| 951 |
+
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
| 952 |
+
|
| 953 |
+
# Prepare input tokens for token by token comparison
|
| 954 |
+
full_input_ids = np.array(full_tokenized["input_ids"])
|
| 955 |
+
|
| 956 |
+
if len(full_input_ids) != len(full_concat_input_ids):
|
| 957 |
+
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
| 958 |
+
|
| 959 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
| 960 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
| 961 |
+
# on the last token from the prompt being different when tokenized on its own
|
| 962 |
+
# vs when done as prompt+answer.
|
| 963 |
+
response_token_ids_start_idx = len(prompt_input_ids)
|
| 964 |
+
|
| 965 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
| 966 |
+
# last token has changed due to merging.
|
| 967 |
+
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
| 968 |
+
response_token_ids_start_idx -= 1
|
| 969 |
+
|
| 970 |
+
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
| 971 |
+
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
| 972 |
+
|
| 973 |
+
if len(prompt_input_ids) != len(prompt_attention_mask):
|
| 974 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
| 975 |
+
|
| 976 |
+
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
| 977 |
+
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
| 978 |
+
|
| 979 |
+
return dict(
|
| 980 |
+
prompt_input_ids=prompt_input_ids,
|
| 981 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 982 |
+
input_ids=answer_input_ids,
|
| 983 |
+
attention_mask=answer_attention_mask,
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
| 987 |
+
"""Tokenize a single row from a ORPO specific dataset.
|
| 988 |
+
|
| 989 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
|
| 990 |
+
chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
|
| 991 |
+
we truncate the chosen/rejected.
|
| 992 |
+
|
| 993 |
+
We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
|
| 994 |
+
of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
|
| 995 |
+
"""
|
| 996 |
+
batch = {}
|
| 997 |
+
prompt = feature["prompt"]
|
| 998 |
+
chosen = feature["chosen"]
|
| 999 |
+
rejected = feature["rejected"]
|
| 1000 |
+
|
| 1001 |
+
if not self.is_encoder_decoder:
|
| 1002 |
+
# Check issues below for more details
|
| 1003 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
| 1004 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 1005 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
| 1006 |
+
|
| 1007 |
+
if not isinstance(prompt, str):
|
| 1008 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
| 1009 |
+
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
| 1010 |
+
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
| 1011 |
+
|
| 1012 |
+
if not isinstance(chosen, str):
|
| 1013 |
+
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
| 1014 |
+
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
| 1015 |
+
|
| 1016 |
+
if not isinstance(rejected, str):
|
| 1017 |
+
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
| 1018 |
+
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
| 1019 |
+
|
| 1020 |
+
# Last prompt token might get merged by tokenizer and
|
| 1021 |
+
# it should not be included for generation if that happens
|
| 1022 |
+
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
| 1023 |
+
|
| 1024 |
+
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
| 1025 |
+
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
| 1026 |
+
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
| 1027 |
+
|
| 1028 |
+
for k, v in prompt_tokens.items():
|
| 1029 |
+
prompt_tokens[k] = v[:prompt_len_input_ids]
|
| 1030 |
+
|
| 1031 |
+
# Make sure prompts only have one different token at most an
|
| 1032 |
+
# and length only differs by 1 at most
|
| 1033 |
+
num_diff_tokens = sum(
|
| 1034 |
+
a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
|
| 1035 |
+
)
|
| 1036 |
+
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
| 1037 |
+
if num_diff_tokens > 1 or num_diff_len > 1:
|
| 1038 |
+
raise ValueError(
|
| 1039 |
+
"Chosen and rejected prompt_input_ids might only differ on the "
|
| 1040 |
+
"last token due to tokenizer merge ops."
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
# add BOS token to head of prompt. Avoid adding if it's already there
|
| 1044 |
+
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
| 1045 |
+
self.processing_class.bos_token_id,
|
| 1046 |
+
prompt_len_input_ids,
|
| 1047 |
+
prompt_tokens,
|
| 1048 |
+
chosen_prompt_len_input_ids,
|
| 1049 |
+
chosen_tokens,
|
| 1050 |
+
rejected_prompt_len_input_ids,
|
| 1051 |
+
rejected_tokens,
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
# add EOS token to end of answer. Avoid adding if it's already there
|
| 1055 |
+
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
| 1056 |
+
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
| 1060 |
+
|
| 1061 |
+
# if combined sequence is too long, truncate the prompt
|
| 1062 |
+
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
| 1063 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 1064 |
+
if self.truncation_mode == "keep_start":
|
| 1065 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 1066 |
+
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
| 1067 |
+
elif self.truncation_mode == "keep_end":
|
| 1068 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 1069 |
+
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
| 1070 |
+
else:
|
| 1071 |
+
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
| 1072 |
+
|
| 1073 |
+
# if that's still too long, truncate the response
|
| 1074 |
+
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
| 1075 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 1076 |
+
for k in ["input_ids", "attention_mask"]:
|
| 1077 |
+
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
| 1078 |
+
|
| 1079 |
+
# Create labels
|
| 1080 |
+
chosen_sequence_tokens = {
|
| 1081 |
+
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 1082 |
+
}
|
| 1083 |
+
rejected_sequence_tokens = {
|
| 1084 |
+
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 1085 |
+
}
|
| 1086 |
+
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
| 1087 |
+
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
| 1088 |
+
self.label_pad_token_id
|
| 1089 |
+
] * len(chosen_tokens["prompt_input_ids"])
|
| 1090 |
+
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
| 1091 |
+
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
| 1092 |
+
self.label_pad_token_id
|
| 1093 |
+
] * len(rejected_tokens["prompt_input_ids"])
|
| 1094 |
+
|
| 1095 |
+
for k, toks in {
|
| 1096 |
+
"chosen_": chosen_sequence_tokens,
|
| 1097 |
+
"rejected_": rejected_sequence_tokens,
|
| 1098 |
+
"": prompt_tokens,
|
| 1099 |
+
}.items():
|
| 1100 |
+
for type_key, tokens in toks.items():
|
| 1101 |
+
if type_key == "token_type_ids":
|
| 1102 |
+
continue
|
| 1103 |
+
batch[f"{k}{type_key}"] = tokens
|
| 1104 |
+
|
| 1105 |
+
else:
|
| 1106 |
+
chosen_tokens = self.processing_class(
|
| 1107 |
+
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 1108 |
+
)
|
| 1109 |
+
rejected_tokens = self.processing_class(
|
| 1110 |
+
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 1111 |
+
)
|
| 1112 |
+
prompt_tokens = self.processing_class(
|
| 1113 |
+
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
+
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
| 1117 |
+
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
| 1118 |
+
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
| 1119 |
+
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
| 1120 |
+
|
| 1121 |
+
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
| 1122 |
+
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 1123 |
+
labels=torch.tensor(batch["rejected_labels"])
|
| 1124 |
+
)
|
| 1125 |
+
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 1126 |
+
labels=torch.tensor(batch["chosen_labels"])
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
if is_torch_xla_available():
|
| 1130 |
+
# Pad the sequences to global max_length to avoid TorchXLA recompilation
|
| 1131 |
+
for k in batch:
|
| 1132 |
+
if "labels" in k or self.is_encoder_decoder:
|
| 1133 |
+
pad_value = self.label_pad_token_id
|
| 1134 |
+
elif k.endswith("_input_ids"):
|
| 1135 |
+
pad_value = self.padding_value
|
| 1136 |
+
elif k.endswith("_attention_mask"):
|
| 1137 |
+
pad_value = 0
|
| 1138 |
+
batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
|
| 1139 |
+
return batch
|
| 1140 |
+
|
| 1141 |
+
@staticmethod
|
| 1142 |
+
def concatenated_inputs(
|
| 1143 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1144 |
+
is_encoder_decoder: bool = False,
|
| 1145 |
+
label_pad_token_id: int = -100,
|
| 1146 |
+
padding_value: int = 0,
|
| 1147 |
+
device: Optional[torch.device] = None,
|
| 1148 |
+
) -> dict[str, torch.LongTensor]:
|
| 1149 |
+
"""Concatenate the chosen and rejected inputs into a single tensor.
|
| 1150 |
+
|
| 1151 |
+
Args:
|
| 1152 |
+
batch:
|
| 1153 |
+
A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
|
| 1154 |
+
of shape (batch_size, sequence_length).
|
| 1155 |
+
is_encoder_decoder:
|
| 1156 |
+
Whether the model is an encoder-decoder model.
|
| 1157 |
+
label_pad_token_id:
|
| 1158 |
+
The label pad token id.
|
| 1159 |
+
padding_value:
|
| 1160 |
+
The padding value to use for the concatenated inputs_ids.
|
| 1161 |
+
device:
|
| 1162 |
+
The device for the concatenated inputs.
|
| 1163 |
+
|
| 1164 |
+
Returns:
|
| 1165 |
+
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
| 1166 |
+
"""
|
| 1167 |
+
concatenated_batch = {}
|
| 1168 |
+
|
| 1169 |
+
if is_encoder_decoder:
|
| 1170 |
+
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
| 1171 |
+
else:
|
| 1172 |
+
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
| 1173 |
+
|
| 1174 |
+
for k in batch:
|
| 1175 |
+
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
| 1176 |
+
if "labels" in k or is_encoder_decoder:
|
| 1177 |
+
pad_value = label_pad_token_id
|
| 1178 |
+
elif k.endswith("_input_ids"):
|
| 1179 |
+
pad_value = padding_value
|
| 1180 |
+
elif k.endswith("_attention_mask"):
|
| 1181 |
+
pad_value = 0
|
| 1182 |
+
concatenated_key = k.replace("chosen", "concatenated")
|
| 1183 |
+
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
| 1184 |
+
for k in batch:
|
| 1185 |
+
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
| 1186 |
+
if "labels" in k or is_encoder_decoder:
|
| 1187 |
+
pad_value = label_pad_token_id
|
| 1188 |
+
elif k.endswith("_input_ids"):
|
| 1189 |
+
pad_value = padding_value
|
| 1190 |
+
elif k.endswith("_attention_mask"):
|
| 1191 |
+
pad_value = 0
|
| 1192 |
+
concatenated_key = k.replace("rejected", "concatenated")
|
| 1193 |
+
concatenated_batch[concatenated_key] = torch.cat(
|
| 1194 |
+
(
|
| 1195 |
+
concatenated_batch[concatenated_key],
|
| 1196 |
+
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
| 1197 |
+
),
|
| 1198 |
+
dim=0,
|
| 1199 |
+
).to(device=device)
|
| 1200 |
+
|
| 1201 |
+
if is_encoder_decoder:
|
| 1202 |
+
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
| 1203 |
+
concatenated_batch["concatenated_attention_mask"] = (
|
| 1204 |
+
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
return concatenated_batch
|
| 1208 |
+
|
| 1209 |
+
def odds_ratio_loss(
|
| 1210 |
+
self,
|
| 1211 |
+
policy_chosen_logps: torch.FloatTensor,
|
| 1212 |
+
policy_rejected_logps: torch.FloatTensor,
|
| 1213 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1214 |
+
"""Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
|
| 1215 |
+
|
| 1216 |
+
Args:
|
| 1217 |
+
policy_chosen_logps:
|
| 1218 |
+
Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
| 1219 |
+
policy_rejected_logps:
|
| 1220 |
+
Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
| 1221 |
+
|
| 1222 |
+
Returns:
|
| 1223 |
+
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO
|
| 1224 |
+
loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
|
| 1225 |
+
the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the
|
| 1226 |
+
rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes.
|
| 1227 |
+
"""
|
| 1228 |
+
|
| 1229 |
+
# Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
|
| 1230 |
+
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
| 1231 |
+
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
|
| 1232 |
+
)
|
| 1233 |
+
ratio = F.logsigmoid(log_odds)
|
| 1234 |
+
losses = self.beta * ratio
|
| 1235 |
+
|
| 1236 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
| 1237 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
| 1238 |
+
|
| 1239 |
+
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
|
| 1240 |
+
|
| 1241 |
+
@staticmethod
|
| 1242 |
+
def get_batch_logps(
|
| 1243 |
+
logits: torch.FloatTensor,
|
| 1244 |
+
labels: torch.LongTensor,
|
| 1245 |
+
average_log_prob: bool = False,
|
| 1246 |
+
label_pad_token_id: int = -100,
|
| 1247 |
+
is_encoder_decoder: bool = False,
|
| 1248 |
+
) -> torch.FloatTensor:
|
| 1249 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
| 1250 |
+
|
| 1251 |
+
Args:
|
| 1252 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 1253 |
+
labels:
|
| 1254 |
+
Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
|
| 1255 |
+
ignored. Shape: (batch_size, sequence_length)
|
| 1256 |
+
average_log_prob:
|
| 1257 |
+
If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
|
| 1258 |
+
log probabilities of the (non-masked) tokens.
|
| 1259 |
+
label_pad_token_id: The label pad token id.
|
| 1260 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 1261 |
+
|
| 1262 |
+
Returns:
|
| 1263 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
|
| 1264 |
+
given logits.
|
| 1265 |
+
"""
|
| 1266 |
+
if logits.shape[:-1] != labels.shape:
|
| 1267 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1268 |
+
|
| 1269 |
+
if not is_encoder_decoder:
|
| 1270 |
+
labels = labels[:, 1:].clone()
|
| 1271 |
+
logits = logits[:, :-1, :]
|
| 1272 |
+
loss_mask = labels != label_pad_token_id
|
| 1273 |
+
|
| 1274 |
+
# dummy token; we'll ignore the losses on these tokens later
|
| 1275 |
+
labels = torch.where(labels == label_pad_token_id, 0, labels)
|
| 1276 |
+
|
| 1277 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
| 1278 |
+
|
| 1279 |
+
if average_log_prob:
|
| 1280 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1281 |
+
else:
|
| 1282 |
+
return (per_token_logps * loss_mask).sum(-1)
|
| 1283 |
+
|
| 1284 |
+
def concatenated_forward(
|
| 1285 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1286 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1287 |
+
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
| 1288 |
+
|
| 1289 |
+
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
| 1290 |
+
"""
|
| 1291 |
+
concatenated_batch = self.concatenated_inputs(
|
| 1292 |
+
batch,
|
| 1293 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 1294 |
+
label_pad_token_id=self.label_pad_token_id,
|
| 1295 |
+
padding_value=self.padding_value,
|
| 1296 |
+
device=self.accelerator.device,
|
| 1297 |
+
)
|
| 1298 |
+
len_chosen = batch["chosen_labels"].shape[0]
|
| 1299 |
+
|
| 1300 |
+
model_kwargs = (
|
| 1301 |
+
{
|
| 1302 |
+
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
| 1303 |
+
}
|
| 1304 |
+
if self.is_encoder_decoder
|
| 1305 |
+
else {}
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
+
if self.aux_loss_enabled:
|
| 1309 |
+
model_kwargs["output_router_logits"] = True
|
| 1310 |
+
|
| 1311 |
+
outputs = model(
|
| 1312 |
+
concatenated_batch["concatenated_input_ids"],
|
| 1313 |
+
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
| 1314 |
+
use_cache=False,
|
| 1315 |
+
**model_kwargs,
|
| 1316 |
+
)
|
| 1317 |
+
all_logits = outputs.logits
|
| 1318 |
+
|
| 1319 |
+
def cross_entropy_loss(logits, labels):
|
| 1320 |
+
if not self.is_encoder_decoder:
|
| 1321 |
+
# Shift so that tokens < n predict n
|
| 1322 |
+
logits = logits[..., :-1, :].contiguous()
|
| 1323 |
+
labels = labels[..., 1:].contiguous()
|
| 1324 |
+
# Flatten the tokens
|
| 1325 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 1326 |
+
logits = logits.view(-1, logits.shape[-1])
|
| 1327 |
+
labels = labels.view(-1)
|
| 1328 |
+
# Enable model parallelism
|
| 1329 |
+
labels = labels.to(logits.device)
|
| 1330 |
+
loss = loss_fct(logits, labels)
|
| 1331 |
+
return loss
|
| 1332 |
+
|
| 1333 |
+
if self.is_encoder_decoder:
|
| 1334 |
+
labels = concatenated_batch["concatenated_labels"].clone()
|
| 1335 |
+
else:
|
| 1336 |
+
labels = concatenated_batch["concatenated_input_ids"].clone()
|
| 1337 |
+
attention_mask = concatenated_batch["concatenated_attention_mask"]
|
| 1338 |
+
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
|
| 1339 |
+
# orpo chosen nll loss is computed over the full prompt and response
|
| 1340 |
+
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
| 1341 |
+
|
| 1342 |
+
all_logps = self.get_batch_logps(
|
| 1343 |
+
all_logits,
|
| 1344 |
+
concatenated_batch["concatenated_labels"],
|
| 1345 |
+
average_log_prob=True,
|
| 1346 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 1347 |
+
label_pad_token_id=self.label_pad_token_id,
|
| 1348 |
+
)
|
| 1349 |
+
|
| 1350 |
+
chosen_logps = all_logps[:len_chosen]
|
| 1351 |
+
rejected_logps = all_logps[len_chosen:]
|
| 1352 |
+
|
| 1353 |
+
if not self.is_encoder_decoder:
|
| 1354 |
+
chosen_logits = all_logits[:len_chosen, :-1, :]
|
| 1355 |
+
rejected_logits = all_logits[len_chosen:, :-1, :]
|
| 1356 |
+
else:
|
| 1357 |
+
chosen_logits = all_logits[:len_chosen]
|
| 1358 |
+
rejected_logits = all_logits[len_chosen:]
|
| 1359 |
+
|
| 1360 |
+
if self.aux_loss_enabled:
|
| 1361 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
|
| 1362 |
+
|
| 1363 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
| 1364 |
+
|
| 1365 |
+
def get_batch_loss_metrics(
|
| 1366 |
+
self,
|
| 1367 |
+
model,
|
| 1368 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1369 |
+
train_eval: Literal["train", "eval"] = "train",
|
| 1370 |
+
):
|
| 1371 |
+
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
| 1372 |
+
metrics = {}
|
| 1373 |
+
|
| 1374 |
+
forward_output = self.concatenated_forward(model, batch)
|
| 1375 |
+
(
|
| 1376 |
+
policy_chosen_logps,
|
| 1377 |
+
policy_rejected_logps,
|
| 1378 |
+
policy_chosen_logits,
|
| 1379 |
+
policy_rejected_logits,
|
| 1380 |
+
policy_nll_loss,
|
| 1381 |
+
) = forward_output[:5]
|
| 1382 |
+
if self.aux_loss_enabled:
|
| 1383 |
+
aux_loss = forward_output[5]
|
| 1384 |
+
|
| 1385 |
+
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
|
| 1386 |
+
policy_chosen_logps, policy_rejected_logps
|
| 1387 |
+
)
|
| 1388 |
+
# full ORPO loss
|
| 1389 |
+
loss = policy_nll_loss - losses.mean()
|
| 1390 |
+
|
| 1391 |
+
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
| 1392 |
+
|
| 1393 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
| 1394 |
+
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
|
| 1395 |
+
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
|
| 1396 |
+
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
|
| 1397 |
+
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
| 1398 |
+
chosen_rewards - rejected_rewards
|
| 1399 |
+
).mean()
|
| 1400 |
+
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
| 1401 |
+
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
| 1402 |
+
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
| 1403 |
+
policy_rejected_logits.detach().mean()
|
| 1404 |
+
).mean()
|
| 1405 |
+
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
| 1406 |
+
policy_chosen_logits.detach().mean()
|
| 1407 |
+
).mean()
|
| 1408 |
+
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
| 1409 |
+
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
| 1410 |
+
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
| 1411 |
+
if is_torch_xla_available():
|
| 1412 |
+
xm.mark_step() # needed because .item() calls
|
| 1413 |
+
for k, v in metrics.items():
|
| 1414 |
+
metrics[k] = v.item()
|
| 1415 |
+
if self.aux_loss_enabled:
|
| 1416 |
+
loss += self.aux_loss_coef * aux_loss
|
| 1417 |
+
|
| 1418 |
+
return loss, metrics
|
| 1419 |
+
|
| 1420 |
+
def compute_loss(
|
| 1421 |
+
self,
|
| 1422 |
+
model: Union[PreTrainedModel, nn.Module],
|
| 1423 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1424 |
+
return_outputs=False,
|
| 1425 |
+
num_items_in_batch=None,
|
| 1426 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1427 |
+
compute_loss_context_manager = (
|
| 1428 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1429 |
+
)
|
| 1430 |
+
|
| 1431 |
+
with compute_loss_context_manager:
|
| 1432 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
| 1433 |
+
|
| 1434 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1435 |
+
loss = loss.to(self.args.device)
|
| 1436 |
+
|
| 1437 |
+
# force log the metrics
|
| 1438 |
+
self.store_metrics(metrics, train_eval="train")
|
| 1439 |
+
|
| 1440 |
+
if return_outputs:
|
| 1441 |
+
return (loss, metrics)
|
| 1442 |
+
return loss
|
| 1443 |
+
|
| 1444 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
| 1445 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1446 |
+
|
| 1447 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
```python
        # the torch amp context manager as some hidden states are silently casted to full precision.
        generate_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with generate_context_manager:
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.processing_class.pad_token_id,
            )

        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

        return policy_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ):
        if not self.use_dpo_data_collator:
            logger.warning(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with torch.no_grad(), prediction_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
        logits = torch.tensor(logits, device=self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
        `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """

        # Sample and save to game log if requested (for one batch to save time)
        if self.generate_during_eval:
            # Generate random indices within the range of the total number of samples
            num_samples = len(dataloader.dataset)
            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
            random_batch_dataset = dataloader.dataset.select(random_indices)
            random_batch = self.data_collator(random_batch_dataset)
            random_batch = self._prepare_inputs(random_batch)

            policy_output_decoded = self.generate_from_model(self.model, random_batch)

            table = pd.DataFrame(
                columns=["Prompt", "Policy"],
                data=[
                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
                ],
            )
            if "wandb" in self.args.report_to:
                wandb.log({"game_log": wandb.Table(data=table)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="game_log.csv",
                    table=table,
                )

        # Base evaluation
        initial_output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        return initial_output

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`dict[str, float]`):
                The values to log.
            start_time (`float`, *optional*):
                Start time of the training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super().log(logs, start_time)
```
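The `store_metrics`/`log` pair above buffers per-step metric values and emits their mean (then clears the buffer) whenever the trainer logs. A minimal framework-free sketch of that pattern; the `MetricAccumulator` name is illustrative, not part of the generated file:

```python
from collections import defaultdict


class MetricAccumulator:
    """Buffer per-step metric values; average and clear them at log time,
    mirroring the trainer's store_metrics/log pattern."""

    def __init__(self):
        self._stored = defaultdict(list)

    def store(self, metrics):
        # Called once per step with a dict of scalar metrics.
        for key, value in metrics.items():
            self._stored[key].append(value)

    def flush(self):
        # Called at log time: emit the mean of each buffered metric.
        averaged = {key: sum(vals) / len(vals) for key, vals in self._stored.items()}
        self._stored.clear()
        return averaged


acc = MetricAccumulator()
acc.store({"rewards/chosen": 1.0})
acc.store({"rewards/chosen": 3.0})
print(acc.flush())  # -> {'rewards/chosen': 2.0}
```

Clearing the buffer on flush is what keeps train and eval windows independent between log events.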
```python
    def _shift_right(self, input_ids):
        if self.decoder_start_token_id is None:
            raise ValueError(
                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = self.decoder_start_token_id

        if self.pad_token_id is None:
            raise ValueError("model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

        return shifted_input_ids
```
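`_shift_right` builds decoder inputs from labels: prepend `decoder_start_token_id`, drop the last token, and map the `-100` ignore index to `pad_token_id`. The same transform on a plain Python list, without tensors; the `shift_right` helper here is a hypothetical sketch, not part of the generated file:

```python
def shift_right(input_ids, decoder_start_token_id, pad_token_id, ignore_index=-100):
    """Shift one sequence of label ids right to form decoder inputs."""
    # Prepend the start token and drop the final position.
    shifted = [decoder_start_token_id] + input_ids[:-1]
    # Labels mark ignored positions with -100; decoder inputs must use pad_token_id.
    return [pad_token_id if tok == ignore_index else tok for tok in shifted]


print(shift_right([5, -100, 7, 8], decoder_start_token_id=0, pad_token_id=1))  # -> [0, 5, 1, 7]
```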
```python
    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

class UnslothORPOTrainer(_UnslothORPOTrainer):
    """

    Initialize ORPOTrainer.

    Args:
        model ([`~transformers.PreTrainedModel`]):
            The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
        args ([`ORPOConfig`]):
            The ORPO config arguments to use for training.
        data_collator ([`~transformers.DataCollator`]):
            The data collator to use for training. If None is specified, the default data collator
            ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset ([`~datasets.Dataset`]):
            The dataset to use for training.
        eval_dataset ([`~datasets.Dataset`]):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
            metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothORPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('orpo_trainer', other_metrics)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass

    pass


if hasattr(logger, "addFilter"):
    import logging
    class HideLoggingMessage(logging.Filter):
        def __init__(self, text): self.text = text
        def filter(self, x): return not (self.text in x.getMessage())
    pass
    logger.addFilter(HideLoggingMessage("`use_cache=True`"))
```
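The precision setup at the top of the generated `__init__` above reduces to a small decision table: forced float32 and pure-bfloat16 full finetuning both disable autocasting, otherwise the Accelerate mixed-precision mode follows the model's own dtype. A rough framework-free sketch of that rule; `select_mixed_precision` is an illustrative name, not a function in the generated file:

```python
def select_mixed_precision(model_is_float16, force_float32=False, full_bf16_finetuning=False):
    """Mirror the generated __init__'s choice of ACCELERATE_MIXED_PRECISION."""
    if force_float32 or full_bf16_finetuning:
        # No autocasting: train in plain float32, or in pure bfloat16 weights.
        return "no"
    # Otherwise the autocast dtype tracks the model's own weight dtype.
    return "fp16" if model_is_float16 else "bf16"


print(select_mixed_precision(model_is_float16=True))   # -> fp16
print(select_mixed_precision(model_is_float16=False))  # -> bf16
```

This is why the generated code raises a `TypeError` when the requested `fp16`/`bf16` flags contradict the model's weight dtype: mixing the two silently would corrupt the loss scale.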
unsloth_compiled_cache/UnslothPPOTrainer.py
ADDED
|
@@ -0,0 +1,1634 @@
```python
"""
2026.3.2
2026.3.4
5.3.0
0.24.0
__UNSLOTH_VERSIONING__
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, BaseTrainer, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, warnings, Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, OnlineTrainerState, Optional, PPOConfig, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, broadcast, create_reference_model, disable_dropout_in_model, exact_div, forward, get_peft_model, get_reporting_integration_callbacks, is_peft_available, math, nn, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, time, torch, warnings, PeftModel, is_peft_available, os, torch)


import os
import math
import logging
from typing import *
from dataclasses import dataclass, field
```
|
| 39 |
+
from packaging.version import Version
|
| 40 |
+
import torch
|
| 41 |
+
import numpy as np
|
| 42 |
+
from contextlib import nullcontext
|
| 43 |
+
from torch.nn import functional as F
|
| 44 |
+
import inspect
|
| 45 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 46 |
+
from transformers.training_args import ParallelMode
|
| 47 |
+
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
|
| 48 |
+
|
| 49 |
+
# Wrap trainer with padding to right and enable training mode
|
| 50 |
+
# Also patches W&B since multiple runs must use wandb.finish()
|
| 51 |
+
import functools
|
| 52 |
+
from types import MethodType
|
| 53 |
+
try:
|
| 54 |
+
from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
|
| 55 |
+
except:
|
| 56 |
+
def reset_unsloth_gradient_checkpointing_buffers(): pass
|
| 57 |
+
def prepare_for_training_mode(f):
|
| 58 |
+
@functools.wraps(f)
|
| 59 |
+
def wrapper(self, *args, **kwargs):
|
| 60 |
+
# Enable training mode
|
| 61 |
+
_was_training = None
|
| 62 |
+
# Get gradient checkpointing setting from training arguments
|
| 63 |
+
use_gc = getattr(self.args, 'gradient_checkpointing', True)
|
| 64 |
+
if hasattr(self, 'model') and hasattr(self.model, "training"):
|
| 65 |
+
_was_training = self.model.training
|
| 66 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_training"):
|
| 67 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 68 |
+
output = f(self, *args, **kwargs)
|
| 69 |
+
# Restore previous mode when possible
|
| 70 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
|
| 71 |
+
if _was_training is False:
|
| 72 |
+
self.model.for_inference()
|
| 73 |
+
elif _was_training is True and hasattr(self.model, "for_training"):
|
| 74 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 75 |
+
# Reset gradient checkpointing buffers to free memory while staying ready for next run
|
| 76 |
+
try:
|
| 77 |
+
reset_unsloth_gradient_checkpointing_buffers()
|
| 78 |
+
except:
|
| 79 |
+
pass
|
| 80 |
+
# Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
|
| 81 |
+
try:
|
| 82 |
+
import wandb
|
| 83 |
+
wandb.finish()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
return output
|
| 87 |
+
return wrapper
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
torch_compile_options = {
|
| 91 |
+
"epilogue_fusion" : True,
|
| 92 |
+
"max_autotune" : False,
|
| 93 |
+
"shape_padding" : True,
|
| 94 |
+
"trace.enabled" : False,
|
| 95 |
+
"triton.cudagraphs" : False,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 99 |
+
def chunked_hidden_states_selective_log_softmax(
|
| 100 |
+
hidden_states: torch.Tensor,
|
| 101 |
+
lm_head: torch.Tensor,
|
| 102 |
+
index: torch.Tensor,
|
| 103 |
+
chunks: int = 4,
|
| 104 |
+
logit_scale_multiply: float = 0.0,
|
| 105 |
+
logit_scale_divide: float = 0.0,
|
| 106 |
+
logit_softcapping: float = 0.0,
|
| 107 |
+
temperature: float = 1.0,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
# All Unsloth Zoo code licensed under AGPL3
|
| 110 |
+
flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
|
| 111 |
+
flat_index = index.reshape(-1)
|
| 112 |
+
|
| 113 |
+
chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
|
| 114 |
+
chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
|
| 115 |
+
|
| 116 |
+
all_per_token_logps = []
|
| 117 |
+
|
| 118 |
+
for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
|
| 119 |
+
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
|
| 120 |
+
|
| 121 |
+
if logit_scale_multiply != 0.0:
|
| 122 |
+
chunk_logits = chunk_logits * logit_scale_multiply
|
| 123 |
+
if logit_scale_divide != 0.0:
|
| 124 |
+
chunk_logits = chunk_logits / logit_scale_divide
|
| 125 |
+
if logit_softcapping != 0.0:
|
| 126 |
+
chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
|
| 127 |
+
|
| 128 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 129 |
+
|
| 130 |
+
if temperature != 1.0:
|
| 131 |
+
chunk_logits = chunk_logits / temperature
|
| 132 |
+
|
| 133 |
+
selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 134 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
|
| 135 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 136 |
+
all_per_token_logps.append(per_token_logps)
|
| 137 |
+
|
| 138 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 139 |
+
|
| 140 |
+
all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
|
| 141 |
+
return all_per_token_logps
|
| 142 |
+
|
| 143 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 144 |
+
def chunked_selective_log_softmax(logits, index):
|
| 145 |
+
# Split into 4 chunks only
|
| 146 |
+
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 147 |
+
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 148 |
+
all_per_token_logps = []
|
| 149 |
+
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 150 |
+
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 151 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 152 |
+
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 153 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 154 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 155 |
+
all_per_token_logps.append(per_token_logps)
|
| 156 |
+
pass
|
| 157 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 158 |
+
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 159 |
+
return all_per_token_logps
|
| 160 |
+
|
| 161 |
+
def calculate_pad_tokens_in_prompt(
|
| 162 |
+
input_ids: torch.Tensor,
|
| 163 |
+
logits_to_keep: int,
|
| 164 |
+
pad_token_id: int
|
| 165 |
+
) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
|
| 168 |
+
"""
|
| 169 |
+
if logits_to_keep >= input_ids.shape[1]:
|
| 170 |
+
raise ValueError("logits_to_keep must be smaller than the sequence length.")
|
| 171 |
+
|
| 172 |
+
prompt_section = input_ids[:, :-logits_to_keep]
|
| 173 |
+
|
| 174 |
+
padding_mask = (prompt_section == pad_token_id)
|
| 175 |
+
|
| 176 |
+
pad_token_counts = padding_mask.sum(dim=1)
|
| 177 |
+
|
| 178 |
+
return pad_token_counts
|
| 179 |
+
|
| 180 |
+
def create_completion_attention_mask(
|
| 181 |
+
completion_input_ids: torch.Tensor,
|
| 182 |
+
left_pad_tokens_per_prompt: torch.Tensor,
|
| 183 |
+
max_left_pad: int,
|
| 184 |
+
pad_token_id: int
|
| 185 |
+
) -> torch.Tensor:
|
| 186 |
+
"""
|
| 187 |
+
Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
|
| 188 |
+
|
| 189 |
+
Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
|
| 190 |
+
and pad are pad tokens, this function would make a completion mask that would 0 out the pad
|
| 191 |
+
and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
|
| 192 |
+
"""
|
| 193 |
+
batch_size, completion_len = completion_input_ids.shape
|
| 194 |
+
device = completion_input_ids.device
|
| 195 |
+
|
| 196 |
+
num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
|
| 197 |
+
|
| 198 |
+
indices = torch.arange(completion_len, device=device).unsqueeze(0)
|
| 199 |
+
shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
|
| 200 |
+
|
| 201 |
+
non_padding_mask = (completion_input_ids != pad_token_id)
|
| 202 |
+
|
| 203 |
+
final_mask = shift_mask & non_padding_mask
|
| 204 |
+
|
| 205 |
+
return final_mask
|
| 206 |
+
|
| 207 |
+
def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
|
| 208 |
+
"""
|
| 209 |
+
Moves all padding tokens in each sequence of a batch to the right.
|
| 210 |
+
"""
|
| 211 |
+
mask = (tensor != pad_id)
|
| 212 |
+
# Must do stable=True since binary mark is unordered
|
| 213 |
+
sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
|
| 214 |
+
packed_tensor = torch.gather(tensor, 1, sorted_indices)
|
| 215 |
+
return packed_tensor
|
| 216 |
+
|
| 217 |
+
def align_logprobs_with_mask(
|
| 218 |
+
logprob_tensor: torch.Tensor,
|
| 219 |
+
attention_mask: torch.Tensor,
|
| 220 |
+
pad_value: float = 0.0
|
| 221 |
+
) -> torch.Tensor:
|
| 222 |
+
"""
|
| 223 |
+
Aligns a log probability tensor with a given attention mask.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
device = logprob_tensor.device
|
| 227 |
+
batch_size, logprob_seq_len = logprob_tensor.shape
|
| 228 |
+
mask_seq_len = attention_mask.shape[1]
|
| 229 |
+
|
| 230 |
+
padded_logprobs = torch.full(
|
| 231 |
+
attention_mask.shape,
|
| 232 |
+
fill_value=pad_value,
|
| 233 |
+
dtype=logprob_tensor.dtype,
|
| 234 |
+
device=device
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
left_pad_counts = torch.argmax(attention_mask, dim=1)
|
| 238 |
+
|
| 239 |
+
cols = torch.arange(logprob_seq_len, device=device)
|
| 240 |
+
dest_indices = left_pad_counts.unsqueeze(1) + cols
|
| 241 |
+
|
| 242 |
+
# Create destination row indices
|
| 243 |
+
# Shape: [batch_size, logprob_seq_len]
|
| 244 |
+
row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
|
| 245 |
+
|
| 246 |
+
# --- 4. Filter out-of-bounds indices and perform assignment ---
|
| 247 |
+
# Create a mask to identify only the indices that are within the bounds
|
| 248 |
+
# of the target tensor's sequence length.
|
| 249 |
+
valid_mask = dest_indices < mask_seq_len
|
| 250 |
+
|
| 251 |
+
# Use this mask to select only the valid row indices, column indices,
|
| 252 |
+
# and the corresponding values from the logprob tensor.
|
| 253 |
+
# This flattens the selected elements into 1D tensors.
|
| 254 |
+
valid_rows = row_indices[valid_mask]
|
| 255 |
+
valid_cols = dest_indices[valid_mask]
|
| 256 |
+
valid_vals = logprob_tensor[valid_mask]
|
| 257 |
+
|
| 258 |
+
# Place the valid values into their correct positions in the padded tensor
|
| 259 |
+
# using a single, efficient advanced indexing operation.
|
| 260 |
+
padded_logprobs[valid_rows, valid_cols] = valid_vals
|
| 261 |
+
|
| 262 |
+
return padded_logprobs
|
| 263 |
+
|
| 264 |
+
def autotune_batch_and_chunks(
|
| 265 |
+
total_input_rows,
|
| 266 |
+
seq_len,
|
| 267 |
+
hidden_size,
|
| 268 |
+
vocab_size,
|
| 269 |
+
dtype_bytes=16,
|
| 270 |
+
multiplier=None
|
| 271 |
+
):
|
| 272 |
+
if multiplier is None:
|
| 273 |
+
final_m = max(4, seq_len // 4096)
|
| 274 |
+
else:
|
| 275 |
+
final_m = multiplier
|
| 276 |
+
|
| 277 |
+
if torch.cuda.is_available():
|
| 278 |
+
free_bytes, _ = torch.cuda.mem_get_info()
|
| 279 |
+
limit_gb = (free_bytes / (1024**3))*.80
|
| 280 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 281 |
+
# For XPU: estimate free memory from total - reserved
|
| 282 |
+
total_mem = torch.xpu.get_device_properties(0).total_memory
|
| 283 |
+
reserved_mem = torch.xpu.memory_reserved()
|
| 284 |
+
free_bytes = total_mem - reserved_mem
|
| 285 |
+
limit_gb = (free_bytes / (1024**3)) * 0.80
|
| 286 |
+
else:
|
| 287 |
+
# Fallback: assume 8GB available
|
| 288 |
+
limit_gb = 8.0
|
| 289 |
+
|
| 290 |
+
bytes_to_gb = 1024**3
|
| 291 |
+
|
| 292 |
+
b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
|
| 293 |
+
|
| 294 |
+
hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
|
| 295 |
+
|
| 296 |
+
base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
|
| 297 |
+
logits_gb = base_logits / final_m
|
| 298 |
+
|
| 299 |
+
total_mem_gb = hidden_gb + logits_gb
|
| 300 |
+
|
| 301 |
+
valid_mask = total_mem_gb <= limit_gb
|
| 302 |
+
valid_indices = torch.nonzero(valid_mask, as_tuple=False)
|
| 303 |
+
|
| 304 |
+
if valid_indices.shape[0] == 0:
|
| 305 |
+
#This means your GPU will OOM
|
| 306 |
+
return 4, final_m
|
| 307 |
+
|
| 308 |
+
best_idx = valid_indices[0].item()
|
| 309 |
+
final_b = int(b_vals[best_idx].item())
|
| 310 |
+
|
| 311 |
+
return final_b, final_m
|
| 312 |
+
|
| 313 |
+
def sanitize_logprob(logprob):
|
| 314 |
+
"""Local port of trl.scripts.vllm_serve.sanitize_logprob.
|
| 315 |
+
Filters NaN logprobs from vLLM outputs."""
|
| 316 |
+
value = logprob.logprob
|
| 317 |
+
if math.isnan(value):
|
| 318 |
+
logging.getLogger(__name__).warning(
|
| 319 |
+
f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
|
| 320 |
+
)
|
| 321 |
+
return None
|
| 322 |
+
return value
|
| 323 |
+
@dataclass
|
| 324 |
+
class UnslothPPOConfig(PPOConfig):
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
Configuration class for the [`PPOTrainer`].
|
| 328 |
+
|
| 329 |
+
This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
|
| 330 |
+
please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
|
| 331 |
+
values in this class may differ from those in [`~transformers.TrainingArguments`].
|
| 332 |
+
|
| 333 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 334 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 335 |
+
command line.
|
| 336 |
+
|
| 337 |
+
Parameters:
|
| 338 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
| 339 |
+
Name of this experiment.
|
| 340 |
+
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
| 341 |
+
Path to the reward model.
|
| 342 |
+
model_adapter_name (`str`, *optional*):
|
| 343 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 344 |
+
ref_adapter_name (`str`, *optional*):
|
| 345 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 346 |
+
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
| 347 |
+
Number of epochs to train.
|
| 348 |
+
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
| 349 |
+
Whether to whiten the rewards.
|
| 350 |
+
kl_coef (`float`, *optional*, defaults to `0.05`):
|
| 351 |
+
KL coefficient.
|
| 352 |
+
kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
|
| 353 |
+
Which estimator for KL-Divergence to use from [Approximating KL
|
| 354 |
+
Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
|
| 355 |
+
estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
|
| 356 |
+
better estimator". Cannot be set to "k2", as it is used for logging purposes.
|
| 357 |
+
cliprange (`float`, *optional*, defaults to `0.2`):
|
| 358 |
+
Clip range.
|
| 359 |
+
vf_coef (`float`, *optional*, defaults to `0.1`):
|
| 360 |
+
Value function coefficient.
|
| 361 |
+
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
| 362 |
+
Clip range for the value function.
|
| 363 |
+
gamma (`float`, *optional*, defaults to `1.0`):
|
| 364 |
+
Discount factor.
|
| 365 |
+
lam (`float`, *optional*, defaults to `0.95`):
|
| 366 |
+
Lambda value for GAE.
|
| 367 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 368 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 369 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 370 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
| 371 |
+
|
| 372 |
+
"""
|
| 373 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 374 |
+
default = None,
|
| 375 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 376 |
+
)
|
| 377 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 378 |
+
default = -1,
|
| 379 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 380 |
+
)
|
| 381 |
+
unsloth_logit_chunk_multiplier : Optional[int] = field(
|
| 382 |
+
default = None,
|
| 383 |
+
metadata = {'help': 'Multiplier for chunked logit computations.'},
|
| 384 |
+
)
|
| 385 |
+
unsloth_grpo_mini_batch : Optional[int] = field(
|
| 386 |
+
default = None,
|
| 387 |
+
metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def __init__(
|
| 391 |
+
self,
|
| 392 |
+
output_dir = None,
|
| 393 |
+
per_device_train_batch_size = 4,
|
| 394 |
+
num_train_epochs = 3.0,
|
| 395 |
+
max_steps = -1,
|
| 396 |
+
learning_rate = 5e-05,
|
| 397 |
+
lr_scheduler_type = 'linear',
|
| 398 |
+
lr_scheduler_kwargs = None,
|
| 399 |
+
warmup_steps = 0.1,
|
| 400 |
+
optim = 'adamw_8bit',
|
| 401 |
+
optim_args = None,
|
| 402 |
+
weight_decay = 0.01,
|
| 403 |
+
adam_beta1 = 0.9,
|
| 404 |
+
adam_beta2 = 0.999,
|
| 405 |
+
adam_epsilon = 1e-08,
|
| 406 |
+
optim_target_modules = None,
|
| 407 |
+
gradient_accumulation_steps = 2,
|
| 408 |
+
average_tokens_across_devices = True,
|
| 409 |
+
max_grad_norm = 1.0,
|
| 410 |
+
label_smoothing_factor = 0.0,
|
| 411 |
+
bf16 = False,
|
| 412 |
+
fp16 = False,
|
| 413 |
+
bf16_full_eval = False,
|
| 414 |
+
fp16_full_eval = False,
|
| 415 |
+
tf32 = None,
|
| 416 |
+
gradient_checkpointing = True,
|
| 417 |
+
gradient_checkpointing_kwargs = None,
|
| 418 |
+
torch_compile = False,
|
| 419 |
+
torch_compile_backend = None,
|
| 420 |
+
torch_compile_mode = None,
|
| 421 |
+
use_liger_kernel = False,
|
| 422 |
+
liger_kernel_config = None,
|
| 423 |
+
use_cache = False,
|
| 424 |
+
neftune_noise_alpha = None,
|
| 425 |
+
torch_empty_cache_steps = 250,
|
| 426 |
+
auto_find_batch_size = False,
|
| 427 |
+
logging_strategy = 'steps',
|
| 428 |
+
logging_steps = 1,
|
| 429 |
+
logging_first_step = False,
|
| 430 |
+
log_on_each_node = True,
|
| 431 |
+
logging_nan_inf_filter = False,
|
| 432 |
+
include_num_input_tokens_seen = False,
|
| 433 |
+
log_level = 'passive',
|
| 434 |
+
log_level_replica = 'warning',
|
| 435 |
+
disable_tqdm = None,
|
| 436 |
+
report_to = 'none',
|
| 437 |
+
run_name = None,
|
| 438 |
+
project = 'huggingface',
|
| 439 |
+
trackio_space_id = 'trackio',
|
| 440 |
+
eval_strategy = 'no',
|
| 441 |
+
eval_steps = None,
|
| 442 |
+
eval_delay = 0,
|
| 443 |
+
per_device_eval_batch_size = 4,
|
| 444 |
+
prediction_loss_only = False,
|
| 445 |
+
eval_on_start = False,
|
| 446 |
+
eval_do_concat_batches = True,
|
| 447 |
+
eval_use_gather_object = False,
|
| 448 |
+
eval_accumulation_steps = 2,
|
| 449 |
+
batch_eval_metrics = False,
|
| 450 |
+
save_only_model = False,
|
| 451 |
+
save_strategy = 'steps',
|
| 452 |
+
save_steps = 500,
|
| 453 |
+
save_on_each_node = False,
|
| 454 |
+
save_total_limit = None,
|
| 455 |
+
enable_jit_checkpoint = False,
|
| 456 |
+
push_to_hub = False,
|
| 457 |
+
hub_token = None,
|
| 458 |
+
hub_private_repo = None,
|
| 459 |
+
hub_model_id = None,
|
| 460 |
+
hub_strategy = 'every_save',
|
| 461 |
+
hub_always_push = False,
|
| 462 |
+
hub_revision = None,
|
| 463 |
+
load_best_model_at_end = False,
|
| 464 |
+
metric_for_best_model = None,
|
| 465 |
+
greater_is_better = None,
|
| 466 |
+
ignore_data_skip = False,
|
| 467 |
+
restore_callback_states_from_checkpoint = False,
|
| 468 |
+
full_determinism = False,
|
| 469 |
+
seed = 3407,
|
| 470 |
+
data_seed = 3407,
|
| 471 |
+
use_cpu = False,
|
| 472 |
+
accelerator_config = None,
|
| 473 |
+
parallelism_config = None,
|
| 474 |
+
dataloader_drop_last = False,
|
| 475 |
+
dataloader_num_workers = 0,
|
| 476 |
+
dataloader_pin_memory = True,
|
| 477 |
+
dataloader_persistent_workers = False,
|
| 478 |
+
dataloader_prefetch_factor = None,
|
| 479 |
+
remove_unused_columns = True,
|
| 480 |
+
label_names = None,
|
| 481 |
+
train_sampling_strategy = 'random',
|
| 482 |
+
length_column_name = 'length',
|
| 483 |
+
ddp_find_unused_parameters = None,
|
| 484 |
+
ddp_bucket_cap_mb = None,
|
| 485 |
+
ddp_broadcast_buffers = None,
|
| 486 |
+
ddp_backend = None,
|
| 487 |
+
ddp_timeout = 1800,
|
| 488 |
+
fsdp = None,
|
| 489 |
+
fsdp_config = None,
|
| 490 |
+
deepspeed = None,
|
| 491 |
+
debug = '',
|
| 492 |
+
skip_memory_metrics = True,
|
| 493 |
+
do_train = False,
|
| 494 |
+
do_eval = False,
|
| 495 |
+
do_predict = False,
|
| 496 |
+
resume_from_checkpoint = None,
|
| 497 |
+
warmup_ratio = None,
|
| 498 |
+
logging_dir = None,
|
| 499 |
+
local_rank = -1,
|
| 500 |
+
dataset_num_proc = None,
|
| 501 |
+
num_mini_batches = 1,
|
| 502 |
+
total_episodes = None,
|
| 503 |
+
local_rollout_forward_batch_size = 64,
|
| 504 |
+
num_sample_generations = 10,
|
| 505 |
+
response_length = 53,
|
| 506 |
+
stop_token = None,
|
| 507 |
+
stop_token_id = None,
|
| 508 |
+
temperature = 0.7,
|
| 509 |
+
missing_eos_penalty = None,
|
| 510 |
+
sft_model_path = 'EleutherAI/pythia-160m',
|
| 511 |
+
world_size = None,
|
| 512 |
+
num_total_batches = None,
|
| 513 |
+
micro_batch_size = None,
|
| 514 |
+
local_batch_size = None,
|
| 515 |
+
batch_size = None,
|
| 516 |
+
local_mini_batch_size = None,
|
| 517 |
+
mini_batch_size = None,
|
| 518 |
+
exp_name = 'ppo_config',
|
| 519 |
+
reward_model_path = 'EleutherAI/pythia-160m',
|
| 520 |
+
model_adapter_name = None,
|
| 521 |
+
ref_adapter_name = None,
|
| 522 |
+
num_ppo_epochs = 4,
|
| 523 |
+
whiten_rewards = False,
|
| 524 |
+
kl_coef = 0.05,
|
| 525 |
+
kl_estimator = 'k1',
|
| 526 |
+
cliprange = 0.2,
|
| 527 |
+
vf_coef = 0.1,
|
| 528 |
+
cliprange_value = 0.2,
|
| 529 |
+
gamma = 1.0,
|
| 530 |
+
lam = 0.95,
|
| 531 |
+
ds3_gather_for_generation = True,
|
| 532 |
+
vllm_sampling_params = None,
|
| 533 |
+
unsloth_num_chunks = -1,
|
| 534 |
+
unsloth_logit_chunk_multiplier = None,
|
| 535 |
+
unsloth_grpo_mini_batch = None,
|
| 536 |
+
|
| 537 |
+
**kwargs,
|
| 538 |
+
):
|
| 539 |
+
if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 540 |
+
if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 541 |
+
if num_train_epochs is None:
|
| 542 |
+
num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
|
| 543 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 544 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 545 |
+
save_strategy = 'no'
|
| 546 |
+
import multiprocessing as _mp
|
| 547 |
+
if _mp.get_start_method() != 'fork':
|
| 548 |
+
dataset_num_proc = None
|
| 549 |
+
elif dataset_num_proc is None:
|
| 550 |
+
import psutil
|
| 551 |
+
dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
|
| 552 |
+
memory_gb_left = psutil.virtual_memory().available / (1024**3)
|
| 553 |
+
if memory_gb_left <= 2: dataset_num_proc = 1
|
| 554 |
+
else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
|
| 555 |
+
if temperature <= 0:
|
| 556 |
+
raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 557 |
+
elif temperature >= 10:
|
| 558 |
+
raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
super().__init__(
|
| 562 |
+
output_dir = output_dir,
|
| 563 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 564 |
+
num_train_epochs = num_train_epochs,
|
| 565 |
+
max_steps = max_steps,
|
| 566 |
+
learning_rate = learning_rate,
|
| 567 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 568 |
+
lr_scheduler_kwargs = lr_scheduler_kwargs,
|
| 569 |
+
warmup_steps = warmup_steps,
|
| 570 |
+
optim = optim,
|
| 571 |
+
optim_args = optim_args,
|
| 572 |
+
weight_decay = weight_decay,
|
| 573 |
+
adam_beta1 = adam_beta1,
|
| 574 |
+
adam_beta2 = adam_beta2,
|
| 575 |
+
adam_epsilon = adam_epsilon,
|
| 576 |
+
optim_target_modules = optim_target_modules,
|
| 577 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 578 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 579 |
+
max_grad_norm = max_grad_norm,
|
| 580 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 581 |
+
bf16 = bf16,
|
| 582 |
+
fp16 = fp16,
|
| 583 |
+
bf16_full_eval = bf16_full_eval,
|
| 584 |
+
fp16_full_eval = fp16_full_eval,
|
| 585 |
+
tf32 = tf32,
|
| 586 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 587 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 588 |
+
torch_compile = torch_compile,
|
| 589 |
+
torch_compile_backend = torch_compile_backend,
|
| 590 |
+
torch_compile_mode = torch_compile_mode,
|
| 591 |
+
use_liger_kernel = use_liger_kernel,
|
| 592 |
+
liger_kernel_config = liger_kernel_config,
|
| 593 |
+
use_cache = use_cache,
|
| 594 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 595 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 596 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 597 |
+
logging_strategy = logging_strategy,
|
| 598 |
+
logging_steps = logging_steps,
|
| 599 |
+
logging_first_step = logging_first_step,
|
| 600 |
+
log_on_each_node = log_on_each_node,
|
| 601 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 602 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 603 |
+
log_level = log_level,
|
| 604 |
+
log_level_replica = log_level_replica,
|
| 605 |
+
disable_tqdm = disable_tqdm,
|
| 606 |
+
report_to = report_to,
|
| 607 |
+
run_name = run_name,
|
| 608 |
+
project = project,
|
| 609 |
+
trackio_space_id = trackio_space_id,
|
| 610 |
+
eval_strategy = eval_strategy,
|
| 611 |
+
eval_steps = eval_steps,
|
| 612 |
+
eval_delay = eval_delay,
|
| 613 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 614 |
+
prediction_loss_only = prediction_loss_only,
|
| 615 |
+
eval_on_start = eval_on_start,
|
| 616 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 617 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 618 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 619 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 620 |
+
save_only_model = save_only_model,
|
| 621 |
+
save_strategy = save_strategy,
|
| 622 |
+
save_steps = save_steps,
|
| 623 |
+
save_on_each_node = save_on_each_node,
|
| 624 |
+
save_total_limit = save_total_limit,
|
| 625 |
+
enable_jit_checkpoint = enable_jit_checkpoint,
|
| 626 |
+
push_to_hub = push_to_hub,
|
| 627 |
+
hub_token = hub_token,
|
| 628 |
+
hub_private_repo = hub_private_repo,
|
| 629 |
+
hub_model_id = hub_model_id,
|
| 630 |
+
hub_strategy = hub_strategy,
|
| 631 |
+
hub_always_push = hub_always_push,
|
| 632 |
+
hub_revision = hub_revision,
|
| 633 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 634 |
+
metric_for_best_model = metric_for_best_model,
|
| 635 |
+
greater_is_better = greater_is_better,
|
| 636 |
+
ignore_data_skip = ignore_data_skip,
|
| 637 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 638 |
+
full_determinism = full_determinism,
|
| 639 |
+
seed = seed,
|
| 640 |
+
data_seed = data_seed,
|
| 641 |
+
use_cpu = use_cpu,
|
| 642 |
+
accelerator_config = accelerator_config,
|
| 643 |
+
parallelism_config = parallelism_config,
|
| 644 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 645 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 646 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 647 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 648 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 649 |
+
remove_unused_columns = remove_unused_columns,
|
| 650 |
+
label_names = label_names,
|
| 651 |
+
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            dataset_num_proc = dataset_num_proc,
            num_mini_batches = num_mini_batches,
            total_episodes = total_episodes,
            local_rollout_forward_batch_size = local_rollout_forward_batch_size,
            num_sample_generations = num_sample_generations,
            response_length = response_length,
            stop_token = stop_token,
            stop_token_id = stop_token_id,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            sft_model_path = sft_model_path,
            world_size = world_size,
            num_total_batches = num_total_batches,
            micro_batch_size = micro_batch_size,
            local_batch_size = local_batch_size,
            batch_size = batch_size,
            local_mini_batch_size = local_mini_batch_size,
            mini_batch_size = mini_batch_size,
            exp_name = exp_name,
            reward_model_path = reward_model_path,
            model_adapter_name = model_adapter_name,
            ref_adapter_name = ref_adapter_name,
            num_ppo_epochs = num_ppo_epochs,
            whiten_rewards = whiten_rewards,
            kl_coef = kl_coef,
            kl_estimator = kl_estimator,
            cliprange = cliprange,
            vf_coef = vf_coef,
            cliprange_value = cliprange_value,
            gamma = gamma,
            lam = lam,
            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier


pass

class _UnslothPPOTrainer(BaseTrainer):
    """"""

    _tag_names = ["trl", "ppo"]
    _name = "PPO"
    _paper = {
        "title": "Fine-Tuning Language Models from Human Preferences",
        "id": "1909.08593",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @article{mziegler2019fine-tuning,
                title = {{Fine-Tuning Language Models from Human Preferences}},
                author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
                year = 2019,
                eprint = {arXiv:1909.08593}
            }"""),
    }

    def __init__(
        self,
        args: PPOConfig,
        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
        model: nn.Module,
        ref_model: Optional[nn.Module],
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: nn.Module,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ) -> None:
        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
            warnings.warn(
                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
                "it and want it to remain, please share your comments here: "
                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
                "TRL_EXPERIMENTAL_SILENCE=1."
            )
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must make a copy of it, or `None` if you use peft."
            )

        self.args = args
        self.processing_class = processing_class
        self.policy_model = model

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Handle stop token settings: update policy model's generation_config to use provided stop token
        if args.stop_token and args.stop_token_id:
            raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
        elif args.stop_token:
            if args.stop_token == "eos":
                self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
            else:
                raise ValueError(
                    f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
                )
        else:
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

        # Check that the kl estimator is valid
        if self.args.kl_estimator not in {"k1", "k3"}:
            raise ValueError(
                "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
                "appears to be a strictly better estimator). See "
                "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
            )

        # peft support
        if not is_peft_available() and peft_config is not None:
            raise ImportError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(self.policy_model, PeftModel):
                self.policy_model = self.policy_model.merge_and_unload()

            # get peft model with the given config
            self.policy_model = get_peft_model(self.policy_model, peft_config)
            if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(self.policy_model)

        self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
        self.model_adapter_name = args.model_adapter_name
        self.ref_adapter_name = args.ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model:
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(self.policy_model)

        self.reward_model = reward_model
        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.value_model = value_model
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        if args.whiten_rewards:
            assert args.local_mini_batch_size >= 8, (
                f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
            )
        # `per_rank_rollout_batch_size` is our `args.local_batch_size`
        # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
        args.num_total_batches = math.ceil(
            args.total_episodes / args.batch_size
        )  # we may train for more than `total_episodes`
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        self.local_dataloader_batch_size = args.local_batch_size

        #########
        # setup model, optimizer, and others
        #########
        for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
            if module is not None:
                disable_dropout_in_model(module)
        self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
        self.model.config = self.policy_model.config  # needed for pushing to hub
        self.create_optimizer_and_scheduler(
            num_training_steps=args.num_total_batches
        )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

        #########
        # trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )
        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        # setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        self.eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=True,
        )  # no need to shuffle eval dataset
        self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

        if self.is_deepspeed_enabled:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )

            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = prepare_deepspeed(
                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
        else:
            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = self.ref_model.to(self.accelerator.device)
            self.reward_model = self.reward_model.to(self.accelerator.device)

    def get_train_dataloader(self) -> DataLoader:
        return self.dataloader

    def get_eval_dataloader(self) -> DataLoader:
        return self.eval_dataloader

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with (
            self.accelerator.unwrap_model(self.model.policy).disable_adapter()
            if self.is_peft_model and not self.ref_adapter_name
            else nullcontext()
        ):
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        backup_model = self.model
        self.model = self.model.policy  # save only the policy

        if self.is_deepspeed_enabled:
            backup_deepspeed = self.deepspeed
            self.deepspeed = self.model

        super().save_model(output_dir, _internal_call)

        self.model = backup_model

        if self.is_deepspeed_enabled:
            self.deepspeed = backup_deepspeed

    def train(self):
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        ref_policy = self.ref_model
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        accelerator.print("===training policy===")
        start_time = time.time()
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        self.state.max_steps = args.num_total_batches
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model
            self.model_wrapped = self.model

        for update in range(1, args.num_total_batches + 1):
            self.state.episode += 1 * args.batch_size
            data = next(iter_dataloader)
            with torch.no_grad():
                queries = data["input_ids"].to(device)
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                values = []
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    query_responses, logitss = batch_generation(
                        unwrapped_model.policy,
                        queries,
                        args.local_rollout_forward_batch_size,
                        processing_class.pad_token_id,
                        generation_config,
                    )

                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                    query = queries[i : i + args.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]
                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
                    logprob = selective_log_softmax(logits, response)
                    del logits
                    empty_cache()

                    if ref_policy is None:
                        with self.null_ref_context():
                            ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                    else:
                        ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= args.temperature + 1e-7
                    ref_logprob = selective_log_softmax(ref_logits, response)
                    del ref_output, ref_logits
                    empty_cache()

                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                    postprocessed_response = response
                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            self.stop_token_id, processing_class.pad_token_id, response
                        )

                    # Response Processing 2. run reward model on the truncated responses
                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                    unwrapped_value_model = accelerator.unwrap_model(model).value_model
                    full_value, _, _ = get_reward(
                        unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                    )
                    value = full_value[:, context_length - 1 : -1].squeeze(-1)
                    _, score, _ = get_reward(
                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )

                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)
                    values.append(value)
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                values = torch.cat(values, 0)
                del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
                empty_cache()
                gc.collect()

                # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
                # Completions not passing that filter will receive a lower score.
                contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty
                # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
                sequence_lengths_p1 = sequence_lengths + 1
                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
                values = torch.masked_fill(values, padding_mask_p1, 0)

                # 4. compute rewards
                # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
                logr = ref_logprobs - logprobs
                kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
                non_score_reward = -args.kl_coef * kl
                rewards = non_score_reward.clone()
                actual_start = torch.arange(rewards.size(0), device=rewards.device)
                actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
                rewards[[actual_start, actual_end]] += scores

                # 5. whiten rewards
                if args.whiten_rewards:
                    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

                # 6. compute advantages and returns
                lastgaelam = 0
                advantages_reversed = []
                gen_length = responses.shape[1]
                for t in reversed(range(gen_length)):
                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                    lastgaelam = delta + args.gamma * args.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = torch.stack(advantages_reversed[::-1], axis=1)
                returns = advantages + values
                advantages = masked_whiten(advantages, ~padding_mask)
                advantages = torch.masked_fill(advantages, padding_mask, 0)
                empty_cache()

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(args.num_ppo_epochs):
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                        with accelerator.accumulate(model):
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                            mb_advantage = advantages[micro_batch_inds]
                            mb_responses = responses[micro_batch_inds]
                            mb_query_responses = query_responses[micro_batch_inds]
                            mb_logprobs = logprobs[micro_batch_inds]
                            mb_return = returns[micro_batch_inds]
                            mb_values = values[micro_batch_inds]

                            output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7
                            new_logprobs = selective_log_softmax(logits, mb_responses)
                            new_logprobs = torch.masked_fill(
                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                            )
                            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                            vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                            vpredclipped = torch.clamp(
                                vpred,
                                mb_values - args.cliprange_value,
                                mb_values + args.cliprange_value,
                            )
                            vf_losses1 = torch.square(vpred - mb_return)
                            vf_losses2 = torch.square(vpredclipped - mb_return)
                            vf_loss_max = torch.max(vf_losses1, vf_losses2)
                            vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                            vf_clipfrac = masked_mean(
                                (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                            )
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                            loss = pg_loss + args.vf_coef * vf_loss
                            accelerator.backward(loss)
                            optimizer.step()
                            optimizer.zero_grad()
                            with torch.no_grad():
                                pg_clipfrac = masked_mean(
                                    (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                                )
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                approxkl = 0.5 * (logprobs_diff**2).mean()
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    pg_clipfrac
                                )
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    vf_clipfrac
                                )
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                        gradient_accumulation_idx += 1
                    minibatch_idx += 1
                    # del everything and empty cache
                    # fmt: off
                    del (
                        output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                        vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                        mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                    )
                    # fmt: on
                    empty_cache()
            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                mean_entropy = (-logprobs).sum(1).mean()
                mean_non_score_reward = non_score_reward.sum(1).mean()
                rlhf_reward = mean_non_score_reward + scores.mean()
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
                metrics["objective/non_score_reward"] = (
                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
                )
                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode
                self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
                self.state.global_step += 1
                self.log(metrics)

            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
                self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
            empty_cache()
            gc.collect()

            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                self.generate_completions(sampling=True)
                empty_cache()
            del (
                query_responses,
                responses,
                postprocessed_responses,
                logprobs,
                ref_logprobs,
                values,
                sequence_lengths,
                contain_eos_token,
                sequence_lengths_p1,
                response_idxs,
                padding_mask,
                padding_mask_p1,
                rewards,
                actual_start,
                actual_end,
                advantages,
                returns,
            )
            empty_cache()

        # HF trainer specifics
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def generate_completions(self, sampling: bool = False):
        args = self.args
        processing_class = self.processing_class
        generation_config = GenerationConfig(
            max_new_tokens=self.args.response_length,
            temperature=(0.01 + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        table = defaultdict(list)
        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            for batch in self.eval_dataloader:
                query = batch["input_ids"]
                with torch.no_grad():
                    context_length = query.shape[1]
                    query_response, _ = batch_generation(
                        unwrapped_model.policy,
                        query,
                        query.shape[0],
                        processing_class.pad_token_id,
                        generation_config,
                    )
                    response = query_response[:, context_length:]
|
| 1352 |
+
postprocessed_response = response
|
| 1353 |
+
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
| 1354 |
+
postprocessed_response = truncate_response(
|
| 1355 |
+
self.stop_token_id, processing_class.pad_token_id, response
|
| 1356 |
+
)
|
| 1357 |
+
table["query"].extend(
|
| 1358 |
+
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
| 1359 |
+
)
|
| 1360 |
+
table["model response"].extend(
|
| 1361 |
+
gather_object(processing_class.batch_decode(postprocessed_response))
|
| 1362 |
+
)
|
| 1363 |
+
|
| 1364 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
| 1365 |
+
_, score, _ = get_reward(
|
| 1366 |
+
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
| 1367 |
+
)
|
| 1368 |
+
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
| 1369 |
+
|
| 1370 |
+
if sampling:
|
| 1371 |
+
break
|
| 1372 |
+
df = pd.DataFrame(table)
|
| 1373 |
+
|
| 1374 |
+
if self.accelerator.is_main_process:
|
| 1375 |
+
if is_rich_available():
|
| 1376 |
+
print_rich_table(df.iloc[0 : 0 + 5])
|
| 1377 |
+
if "wandb" in args.report_to:
|
| 1378 |
+
import wandb
|
| 1379 |
+
|
| 1380 |
+
if wandb.run is not None:
|
| 1381 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
| 1382 |
+
|
| 1383 |
+
if "comet_ml" in args.report_to:
|
| 1384 |
+
log_table_to_comet_experiment(
|
| 1385 |
+
name="completions.csv",
|
| 1386 |
+
table=df,
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
# Ensure the model card is saved along with the checkpoint
|
| 1390 |
+
def _save_checkpoint(self, model, trial):
|
| 1391 |
+
if self.args.hub_model_id is None:
|
| 1392 |
+
model_name = Path(self.args.output_dir).name
|
| 1393 |
+
else:
|
| 1394 |
+
model_name = self.args.hub_model_id.split("/")[-1]
|
| 1395 |
+
self.create_model_card(model_name=model_name)
|
| 1396 |
+
super()._save_checkpoint(model, trial)
|
| 1397 |
+
class UnslothPPOTrainer(_UnslothPPOTrainer):
    """
    Trainer for Proximal Policy Optimization (PPO).

    For details on PPO, see the paper: [Proximal Policy Optimization
    Algorithms](https://huggingface.co/papers/1707.06347).

    Args:
        args ([`PPOConfig`]):
            Training arguments.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
            Class to process the data.
        model (`torch.nn.Module`):
            Model to be trained. This is the policy model.
        ref_model (`torch.nn.Module`, *optional*):
            Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
        reward_model (`torch.nn.Module`):
            Reward model used to compute the rewards.
        train_dataset ([`~datasets.Dataset`]):
            Dataset for training.
        value_model (`torch.nn.Module`):
            Value model used to predict the value of a state.
        data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
            Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
            using the `processing_class`.
        eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
            Dataset for evaluation.
        optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
            Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
            optimizer and the learning rate scheduler are created using the
            [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
        callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
            Callbacks to use during training.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
            will be wrapped with the specified PEFT adapter.
    """
    def __init__(
        self,
        args,
        processing_class,
        model,
        ref_model,
        reward_model,
        train_dataset,
        value_model,
        data_collator = None,
        eval_dataset = None,
        callbacks = None,
        peft_config = None,
        **kwargs
    ):
        if args is None: args = UnslothPPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ppo_trainer', other_metrics)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            args = args,
            processing_class = processing_class,
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            train_dataset = train_dataset,
            value_model = value_model,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            callbacks = callbacks,
            peft_config = peft_config, **kwargs)
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass
pass
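The precision checks at the top of `UnslothPPOTrainer.__init__` above can be restated as one small standalone function. This is an illustrative sketch, not library code: the name `resolve_mixed_precision` and its boolean arguments are ours, standing in for the constructor's `float16`, `use_bf16`, `use_fp16`, and `force_float32` locals, and the return value mirrors the `ACCELERATE_MIXED_PRECISION` setting.

```python
# Simplified restatement of the mixed-precision resolution in the
# constructor above (illustrative names, not the library's API):
# a float16 model must not be trained under bf16 autocast and a
# bfloat16 model must not be trained under fp16 autocast, and
# forcing float32 disables mixed precision entirely.
def resolve_mixed_precision(model_is_float16, use_bf16=False, use_fp16=False, force_float32=False):
    if force_float32:
        return "no"  # corresponds to ACCELERATE_MIXED_PRECISION = 'no'
    if model_is_float16 and use_bf16:
        raise TypeError("Model is float16 but bf16 requested: set fp16=True, bf16=False")
    if not model_is_float16 and use_fp16:
        raise TypeError("Model is bfloat16 but fp16 requested: set fp16=False, bf16=True")
    if not use_bf16 and not use_fp16:
        # Neither flag set: follow the model's own dtype.
        return "fp16" if model_is_float16 else "bf16"
    return "fp16" if use_fp16 else "bf16"
```

Following the model's own dtype when neither flag is set matches the `(not use_bf16 and not use_fp16)` branch above, which picks `fp16` for float16 weights and `bf16` otherwise.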
unsloth_compiled_cache/UnslothPRMTrainer.py
ADDED
@@ -0,0 +1,1109 @@
| 1 |
+
"""
|
| 2 |
+
2026.3.2
|
| 3 |
+
2026.3.4
|
| 4 |
+
5.3.0
|
| 5 |
+
0.24.0
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth auto generated code
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from torch.nn import functional as F
|
| 29 |
+
from unsloth_zoo.temporary_patches.common import torch_compile
|
| 30 |
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 31 |
+
from trl.trainer.prm_trainer import (BaseImageProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, nn, os, textwrap, torch, warnings, BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PartialState, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, compute_accuracy, disable_dropout_in_model, features, nn, os, torch, warnings, PreTrainedModel, os, torch)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
import math
|
| 36 |
+
import logging
|
| 37 |
+
from typing import *
|
| 38 |
+
from dataclasses import dataclass, field
|
| 39 |
+
from packaging.version import Version
|
| 40 |
+
import torch
|
| 41 |
+
import numpy as np
|
| 42 |
+
from contextlib import nullcontext
|
| 43 |
+
from torch.nn import functional as F
|
| 44 |
+
import inspect
|
| 45 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 46 |
+
from transformers.training_args import ParallelMode
|
| 47 |
+
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
|
| 48 |
+
|
| 49 |
+
# Wrap trainer with padding to right and enable training mode
|
| 50 |
+
# Also patches W&B since multiple runs must use wandb.finish()
|
| 51 |
+
import functools
|
| 52 |
+
from types import MethodType
|
| 53 |
+
try:
|
| 54 |
+
from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
|
| 55 |
+
except:
|
| 56 |
+
def reset_unsloth_gradient_checkpointing_buffers(): pass
|
| 57 |
+
def prepare_for_training_mode(f):
|
| 58 |
+
@functools.wraps(f)
|
| 59 |
+
def wrapper(self, *args, **kwargs):
|
| 60 |
+
# Enable training mode
|
| 61 |
+
_was_training = None
|
| 62 |
+
# Get gradient checkpointing setting from training arguments
|
| 63 |
+
use_gc = getattr(self.args, 'gradient_checkpointing', True)
|
| 64 |
+
if hasattr(self, 'model') and hasattr(self.model, "training"):
|
| 65 |
+
_was_training = self.model.training
|
| 66 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_training"):
|
| 67 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 68 |
+
output = f(self, *args, **kwargs)
|
| 69 |
+
# Restore previous mode when possible
|
| 70 |
+
if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
|
| 71 |
+
if _was_training is False:
|
| 72 |
+
self.model.for_inference()
|
| 73 |
+
elif _was_training is True and hasattr(self.model, "for_training"):
|
| 74 |
+
self.model.for_training(use_gradient_checkpointing=use_gc)
|
| 75 |
+
# Reset gradient checkpointing buffers to free memory while staying ready for next run
|
| 76 |
+
try:
|
| 77 |
+
reset_unsloth_gradient_checkpointing_buffers()
|
| 78 |
+
except:
|
| 79 |
+
pass
|
| 80 |
+
# Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
|
| 81 |
+
try:
|
| 82 |
+
import wandb
|
| 83 |
+
wandb.finish()
|
| 84 |
+
except:
|
| 85 |
+
pass
|
| 86 |
+
return output
|
| 87 |
+
return wrapper
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
torch_compile_options = {
|
| 91 |
+
"epilogue_fusion" : True,
|
| 92 |
+
"max_autotune" : False,
|
| 93 |
+
"shape_padding" : True,
|
| 94 |
+
"trace.enabled" : False,
|
| 95 |
+
"triton.cudagraphs" : False,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 99 |
+
def chunked_hidden_states_selective_log_softmax(
|
| 100 |
+
hidden_states: torch.Tensor,
|
| 101 |
+
lm_head: torch.Tensor,
|
| 102 |
+
index: torch.Tensor,
|
| 103 |
+
chunks: int = 4,
|
| 104 |
+
logit_scale_multiply: float = 0.0,
|
| 105 |
+
logit_scale_divide: float = 0.0,
|
| 106 |
+
logit_softcapping: float = 0.0,
|
| 107 |
+
temperature: float = 1.0,
|
| 108 |
+
) -> torch.Tensor:
|
| 109 |
+
# All Unsloth Zoo code licensed under AGPL3
|
| 110 |
+
flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
|
| 111 |
+
flat_index = index.reshape(-1)
|
| 112 |
+
|
| 113 |
+
chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
|
| 114 |
+
chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
|
| 115 |
+
|
| 116 |
+
all_per_token_logps = []
|
| 117 |
+
|
| 118 |
+
for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
|
| 119 |
+
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
|
| 120 |
+
|
| 121 |
+
if logit_scale_multiply != 0.0:
|
| 122 |
+
chunk_logits = chunk_logits * logit_scale_multiply
|
| 123 |
+
if logit_scale_divide != 0.0:
|
| 124 |
+
chunk_logits = chunk_logits / logit_scale_divide
|
| 125 |
+
if logit_softcapping != 0.0:
|
| 126 |
+
chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
|
| 127 |
+
|
| 128 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 129 |
+
|
| 130 |
+
if temperature != 1.0:
|
| 131 |
+
chunk_logits = chunk_logits / temperature
|
| 132 |
+
|
| 133 |
+
selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 134 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
|
| 135 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 136 |
+
all_per_token_logps.append(per_token_logps)
|
| 137 |
+
|
| 138 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 139 |
+
|
| 140 |
+
all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
|
| 141 |
+
return all_per_token_logps
|
| 142 |
+
|
| 143 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 144 |
+
def chunked_selective_log_softmax(logits, index):
|
| 145 |
+
# Split into 4 chunks only
|
| 146 |
+
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 147 |
+
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 148 |
+
all_per_token_logps = []
|
| 149 |
+
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 150 |
+
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 151 |
+
chunk_logits = chunk_logits.to(torch.float32)
|
| 152 |
+
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 153 |
+
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 154 |
+
per_token_logps = selected_logits - logsumexp_values
|
| 155 |
+
all_per_token_logps.append(per_token_logps)
|
| 156 |
+
pass
|
| 157 |
+
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 158 |
+
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 159 |
+
return all_per_token_logps
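The chunked routines above compute, per token, the identity `log_softmax(logits)[idx] == logits[idx] - logsumexp(logits)`; chunking only changes how many rows are materialized at once, not the result. A minimal dependency-free sketch of that identity (`selective_log_softmax_row` is a hypothetical helper, not part of this file):

```python
import math

def selective_log_softmax_row(logits, idx):
    # Numerically stable per-token log-probability:
    # log_softmax(logits)[idx] == logits[idx] - logsumexp(logits)
    m = max(logits)
    logsumexp = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[idx] - logsumexp

# Uniform logits over 4 classes -> log(1/4) for any selected index
print(selective_log_softmax_row([0.0, 0.0, 0.0, 0.0], 2))  # ~ -1.386
```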

def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a prompt tensor, return the number of left-padding tokens in each sequence,
    e.g. [pad, pad, pad, cat] -> 3 tokens.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    prompt_section = input_ids[:, :-logits_to_keep]

    padding_mask = (prompt_section == pad_token_id)

    pad_token_counts = padding_mask.sum(dim=1)

    return pad_token_counts

def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra prompt tokens
    obtained from slicing the tensor, c are completion tokens, and pad are pad tokens,
    build a completion mask that zeroes out the pad and p tokens:
    in this example, [0, 0, 0, 1, 1, 1, 0, 0, 0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device=device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask

    return final_mask

def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    # Must use stable=True since the binary mask is unordered
    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor
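`left_pack_padding` relies on a stable descending argsort of the pad mask, which amounts to a stable partition of each row: non-pad tokens keep their relative order and pads move to the right. A dependency-free sketch of the same per-row behavior (`left_pack_row` is a hypothetical helper for illustration):

```python
def left_pack_row(row, pad_id):
    # Stable partition: non-pad tokens keep their relative order,
    # all pad tokens are moved to the right end of the row.
    return [t for t in row if t != pad_id] + [t for t in row if t == pad_id]

print(left_pack_row([0, 0, 5, 0, 7, 9], pad_id=0))  # [5, 7, 9, 0, 0, 0]
```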

def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Aligns a log probability tensor with a given attention mask.
    """

    device = logprob_tensor.device
    batch_size, logprob_seq_len = logprob_tensor.shape
    mask_seq_len = attention_mask.shape[1]

    padded_logprobs = torch.full(
        attention_mask.shape,
        fill_value=pad_value,
        dtype=logprob_tensor.dtype,
        device=device
    )

    left_pad_counts = torch.argmax(attention_mask, dim=1)

    cols = torch.arange(logprob_seq_len, device=device)
    dest_indices = left_pad_counts.unsqueeze(1) + cols

    # Create destination row indices
    # Shape: [batch_size, logprob_seq_len]
    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)

    # --- Filter out-of-bounds indices and perform assignment ---
    # Create a mask to identify only the indices that are within the bounds
    # of the target tensor's sequence length.
    valid_mask = dest_indices < mask_seq_len

    # Use this mask to select only the valid row indices, column indices,
    # and the corresponding values from the logprob tensor.
    # This flattens the selected elements into 1D tensors.
    valid_rows = row_indices[valid_mask]
    valid_cols = dest_indices[valid_mask]
    valid_vals = logprob_tensor[valid_mask]

    # Place the valid values into their correct positions in the padded tensor
    # using a single, efficient advanced indexing operation.
    padded_logprobs[valid_rows, valid_cols] = valid_vals

    return padded_logprobs

def autotune_batch_and_chunks(
    total_input_rows,
    seq_len,
    hidden_size,
    vocab_size,
    dtype_bytes=16,
    multiplier=None
):
    if multiplier is None:
        final_m = max(4, seq_len // 4096)
    else:
        final_m = multiplier

    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()
        limit_gb = (free_bytes / (1024**3)) * 0.80
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # For XPU: estimate free memory from total - reserved
        total_mem = torch.xpu.get_device_properties(0).total_memory
        reserved_mem = torch.xpu.memory_reserved()
        free_bytes = total_mem - reserved_mem
        limit_gb = (free_bytes / (1024**3)) * 0.80
    else:
        # Fallback: assume 8GB available
        limit_gb = 8.0

    bytes_to_gb = 1024**3

    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)

    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb

    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
    logits_gb = base_logits / final_m

    total_mem_gb = hidden_gb + logits_gb

    valid_mask = total_mem_gb <= limit_gb
    valid_indices = torch.nonzero(valid_mask, as_tuple=False)

    if valid_indices.shape[0] == 0:
        # This means your GPU will OOM
        return 4, final_m

    best_idx = valid_indices[0].item()
    final_b = int(b_vals[best_idx].item())

    return final_b, final_m
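The autotuner above scores each candidate batch size by an estimated hidden-state footprint plus a chunked-logits footprint, and keeps the largest candidate under the memory limit. A dependency-free sketch of that per-candidate estimate (`estimate_gb` is a hypothetical helper; the numbers below are illustrative, not from the source):

```python
def estimate_gb(batch, seq_len, hidden_size, vocab_size,
                dtype_bytes, chunk_multiplier, total_rows):
    gb = 1024 ** 3
    # Hidden states are kept for the whole candidate batch.
    hidden_gb = batch * seq_len * hidden_size * dtype_bytes / gb
    # Logits are materialized in chunks, so their cost is divided by the
    # chunk multiplier (and scaled by the fraction of rows processed).
    logits_gb = (batch / total_rows) * batch * seq_len * vocab_size \
        * dtype_bytes / (gb * chunk_multiplier)
    return hidden_gb + logits_gb

# One row of a 4096-token sequence, bf16, vocab 128256, 4 logit chunks
print(round(estimate_gb(1, 4096, 4096, 128256, 2, 4, 8), 3))  # ~0.062 GB
```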

def sanitize_logprob(logprob):
    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
    Filters NaN logprobs from vLLM outputs."""
    value = logprob.logprob
    if math.isnan(value):
        logging.getLogger(__name__).warning(
            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
        )
        return None
    return value

@dataclass
class UnslothPRMConfig(PRMConfig):
    """

    Configuration class for the [`PRMTrainer`].

    This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) used for truncation.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt used for truncation.
        max_completion_length (`int`, *optional*):
            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        step_separator (`str`, *optional*, defaults to `"\n"`):
            Separator used to separate each step of the reasoning process.
        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
            Whether to train only on the last step.
        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    unsloth_logit_chunk_multiplier : Optional[int] = field(
        default = None,
        metadata = {'help': 'Multiplier for chunked logit computations.'},
    )
    unsloth_grpo_mini_batch : Optional[int] = field(
        default = None,
        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        per_device_train_batch_size = 4,
        num_train_epochs = 3.0,
        max_steps = -1,
        learning_rate = 5e-05,
        lr_scheduler_type = 'linear',
        lr_scheduler_kwargs = None,
        warmup_steps = 0.1,
        optim = 'adamw_8bit',
        optim_args = None,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        optim_target_modules = None,
        gradient_accumulation_steps = 2,
        average_tokens_across_devices = True,
        max_grad_norm = 1.0,
        label_smoothing_factor = 0.0,
        bf16 = False,
        fp16 = False,
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = None,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        use_liger_kernel = False,
        liger_kernel_config = None,
        use_cache = False,
        neftune_noise_alpha = None,
        torch_empty_cache_steps = 250,
        auto_find_batch_size = False,
        logging_strategy = 'steps',
        logging_steps = 1,
        logging_first_step = False,
        log_on_each_node = True,
        logging_nan_inf_filter = False,
        include_num_input_tokens_seen = False,
        log_level = 'passive',
        log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        disable_dropout = True,
        step_separator = '\n',
        train_on_last_step_only = False,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if num_train_epochs is None:
            num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        import multiprocessing as _mp
        if _mp.get_start_method() != 'fork':
            dataset_num_proc = None
        elif dataset_num_proc is None:
            import psutil
            dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
            memory_gb_left = psutil.virtual_memory().available / (1024**3)
            if memory_gb_left <= 2: dataset_num_proc = 1
            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))

        super().__init__(
            output_dir = output_dir,
            per_device_train_batch_size = per_device_train_batch_size,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            learning_rate = learning_rate,
            lr_scheduler_type = lr_scheduler_type,
            lr_scheduler_kwargs = lr_scheduler_kwargs,
            warmup_steps = warmup_steps,
            optim = optim,
            optim_args = optim_args,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            optim_target_modules = optim_target_modules,
            gradient_accumulation_steps = gradient_accumulation_steps,
            average_tokens_across_devices = average_tokens_across_devices,
            max_grad_norm = max_grad_norm,
            label_smoothing_factor = label_smoothing_factor,
            bf16 = bf16,
            fp16 = fp16,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            use_cache = use_cache,
            neftune_noise_alpha = neftune_noise_alpha,
            torch_empty_cache_steps = torch_empty_cache_steps,
            auto_find_batch_size = auto_find_batch_size,
            logging_strategy = logging_strategy,
            logging_steps = logging_steps,
            logging_first_step = logging_first_step,
            log_on_each_node = log_on_each_node,
            logging_nan_inf_filter = logging_nan_inf_filter,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            log_level = log_level,
            log_level_replica = log_level_replica,
            disable_tqdm = disable_tqdm,
            report_to = report_to,
            run_name = run_name,
            project = project,
            trackio_space_id = trackio_space_id,
            eval_strategy = eval_strategy,
            eval_steps = eval_steps,
            eval_delay = eval_delay,
            per_device_eval_batch_size = per_device_eval_batch_size,
            prediction_loss_only = prediction_loss_only,
            eval_on_start = eval_on_start,
            eval_do_concat_batches = eval_do_concat_batches,
            eval_use_gather_object = eval_use_gather_object,
            eval_accumulation_steps = eval_accumulation_steps,
            batch_eval_metrics = batch_eval_metrics,
            save_only_model = save_only_model,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_on_each_node = save_on_each_node,
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            disable_dropout = disable_dropout,
            step_separator = step_separator,
            train_on_last_step_only = train_on_last_step_only,
            dataset_num_proc = dataset_num_proc,**kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length

pass

class _UnslothPRMTrainer(BaseTrainer):
    """"""

    _tag_names = ["trl", "prm"]
    _name = "PRM"
    _paper = {
        "title": "Solving math word problems with process- and outcome-based feedback",
        "id": "2211.14275",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
        @article{uesato2022solving,
            title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
            author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
            year = 2022,
            journal = {arXiv preprint arXiv:2211.14275}
        }"""),
    }

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[PRMConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
            warnings.warn(
                "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
                "it and want it to remain, please share your comments here: "
                "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
                "TRL_EXPERIMENTAL_SILENCE=1."
            )
        if False:
            pass

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default DataCollatorForTokenClassification"
                )
            data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)

        if "input_ids" not in train_dataset.column_names:
            with PartialState().main_process_first():
                fn_kwargs = {
                    "tokenizer": processing_class,
                    "step_separator": args.step_separator,
                    "max_length": args.max_length,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                    "train_on_last_step_only": args.train_on_last_step_only,
                }
                train_fn_kwargs = {**fn_kwargs, "is_eval": False}
                train_dataset = train_dataset.map(
                    self.tokenize_row,
                    fn_kwargs=train_fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=train_dataset.features,
                    desc="Tokenizing train dataset",
                    features=features.Features(  # needed to avoid map to cast labels to bool
                        {
                            "labels": features.Sequence(features.Value("int64")),
                            "input_ids": features.Sequence(features.Value("int64")),
                        }
                    ),
                )

                eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
                if eval_dataset is not None:
                    eval_dataset = eval_dataset.map(
                        self.tokenize_row,
                        fn_kwargs=eval_fn_kwargs,
                        num_proc=args.dataset_num_proc,
                        remove_columns=eval_dataset.features,
                        desc="Tokenizing eval dataset",
                        features=features.Features(  # needed to avoid map to cast labels to bool
                            {
                                "labels": features.Sequence(features.Value("int64")),
                                "input_ids": features.Sequence(features.Value("int64")),
                            }
                        ),
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    @staticmethod
    def tokenize_row(
        features,
        tokenizer,
        step_separator,
        max_length,
        max_prompt_length,
        max_completion_length,
        train_on_last_step_only,
        is_eval,
    ):
        r"""
        Tokenize a row of the dataset.

        Args:
            features (`dict[str, str]`):
                Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
            tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
                Tokenizer used to process the data.
            step_separator (`str`):
                Separator between steps in the completion.
            max_length (`int` or `None`):
                Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
            max_prompt_length (`int` or `None`):
                Maximum length of the prompt. If `None`, the prompt is not truncated.
            max_completion_length (`int` or `None`):
                Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
            train_on_last_step_only (`bool`):
                Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
                token of the completion.
            is_eval (`bool`):
                Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
                `train_on_last_step_only` is set to `True`.

        Returns:
            `dict[str, list[int]]`:
                Tokenized sequences with the keys `"input_ids"` and `"labels"`.

        Example:
        ```python
        >>> from transformers import AutoTokenizer

        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        >>> features = {
        ...     "prompt": "Which number is larger, 9.8 or 9.11?",
        ...     "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
        ...     "labels": [True, False],
        ... }
        >>> PRMTrainer.tokenize_row(
        ...     features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False
        ... )
        {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
         'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
        ```
        """
        # Tokenize the prompt and completions
        prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        completions_ids = [
            tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
        ]
        if train_on_last_step_only and not is_eval:
            labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
        else:
            labels = [int(label) for label in features["labels"]]

        # Get the ID of the separator token and add it to the completions
        separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
        completions_ids = [completion + separator_ids for completion in completions_ids]

        # Create the label
        labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]

        # Join the completions and labels steps
        completion_ids = list(chain(*completions_ids))
        labels = list(chain(*labels))

        if tokenizer.bos_token_id is not None:
            prompt_ids = [tokenizer.bos_token_id] + prompt_ids

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            prompt_ids = prompt_ids[-max_prompt_length:]
        if max_completion_length is not None:
            completion_ids = completion_ids[:max_completion_length]
            labels = labels[:max_completion_length]

        input_ids = prompt_ids + completion_ids
        labels = [-100] * len(prompt_ids) + labels

        if max_length is not None:
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]

        return {"input_ids": input_ids, "labels": labels}
|
| 862 |
+
|
| 863 |
+
# Ensure the model card is saved along with the checkpoint
|
| 864 |
+
def _save_checkpoint(self, model, trial):
|
| 865 |
+
if self.args.hub_model_id is None:
|
| 866 |
+
model_name = Path(self.args.output_dir).name
|
| 867 |
+
else:
|
| 868 |
+
model_name = self.args.hub_model_id.split("/")[-1]
|
| 869 |
+
self.create_model_card(model_name=model_name)
|
| 870 |
+
super()._save_checkpoint(model, trial)
|
| 871 |
+
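The step-labeling scheme used by the tokenizer above supervises only the last token of each step (the separator position); every other position gets the ignore index -100. A dependency-free sketch of that masking (the helper name and the token ids are made up for illustration, this is not the TRL API):

```python
from itertools import chain

IGNORE = -100  # label value ignored by the cross-entropy loss

def mask_step_labels(steps_ids, step_labels):
    """Label only the last token of each step; everything else is IGNORE."""
    per_step = [[IGNORE] * (len(ids) - 1) + [label]
                for ids, label in zip(steps_ids, step_labels)]
    return list(chain(*per_step))

steps = [[11, 12, 13], [21, 22]]   # two steps, separator token already appended
labels = mask_step_labels(steps, [1, 0])
```

Each step of length n contributes n - 1 ignore labels plus its own 0/1 correctness label, which is exactly the placement shown in the docstring example.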
class UnslothPRMTrainer(_UnslothPRMTrainer):
    """

    Initialize PRMTrainer.

    Args:
        model ([`~transformers.PreTrainedModel`]):
            The model to train, preferably an `AutoModelForTokenClassification`.
        args ([`PRMConfig`]):
            The arguments to use for training.
        data_collator ([`~transformers.DataCollator`]):
            The data collator to use for training. If None is specified, the default data collator
            ([`~transformers.DataCollatorForTokenClassification`]) will be used, which will pad the sequences to the
            maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset ([`~datasets.Dataset`]):
            The dataset to use for training.
        eval_dataset ([`~datasets.Dataset`]):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If provided, it will be used to automatically process the inputs
            for the model, and it will be saved along with the model to make it easier to rerun an interrupted training
            or reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
            The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
            will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        if args is None: args = UnslothPRMConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            # args.mixed_precision is a new argument which needs to be set now
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('prm_trainer', other_metrics)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            **kwargs,
        )
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
unsloth_compiled_cache/UnslothRewardTrainer.py
ADDED
|
@@ -0,0 +1,1327 @@
"""
2026.3.2
2026.3.4
5.3.0
0.24.0
__UNSLOTH_VERSIONING__
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from trl.trainer.reward_trainer import (Any, AutoModelForSequenceClassification, AutoTokenizer, BaseTrainer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PartialState, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardTrainer, TrainerCallback, Union, clone_chat_template, contextlib, dataclass, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pad, re, remove_none_values, suppress_from_pretrained_warning, torch, transformers)

import os
import math
import logging
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
import inspect
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
from transformers.training_args import ParallelMode
from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize

# Wrap trainer with padding to right and enable training mode
# Also patches W&B since multiple runs must use wandb.finish()
import functools
from types import MethodType
try:
    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
except:
    def reset_unsloth_gradient_checkpointing_buffers(): pass

def prepare_for_training_mode(f):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        # Enable training mode
        _was_training = None
        # Get gradient checkpointing setting from training arguments
        use_gc = getattr(self.args, 'gradient_checkpointing', True)
        if hasattr(self, 'model') and hasattr(self.model, "training"):
            _was_training = self.model.training
        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
            self.model.for_training(use_gradient_checkpointing=use_gc)
        output = f(self, *args, **kwargs)
        # Restore previous mode when possible
        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
            if _was_training is False:
                self.model.for_inference()
            elif _was_training is True and hasattr(self.model, "for_training"):
                self.model.for_training(use_gradient_checkpointing=use_gc)
        # Reset gradient checkpointing buffers to free memory while staying ready for next run
        try:
            reset_unsloth_gradient_checkpointing_buffers()
        except:
            pass
        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
        try:
            import wandb
            wandb.finish()
        except:
            pass
        return output
    return wrapper
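prepare_for_training_mode above is an ordinary decorator: snapshot the model's mode, switch to training, run the wrapped method, then restore. The save/toggle/restore pattern can be sketched in isolation with dummy objects (all names below are illustrative stand-ins, not Unsloth APIs):

```python
import functools

def with_training_mode(f):
    """Toggle a .training flag around a method call, restoring the old value."""
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        was_training = self.model.training
        self.model.training = True              # enter training mode
        try:
            return f(self, *args, **kwargs)
        finally:
            self.model.training = was_training  # restore previous mode even on error

    return wrapper

class DummyModel:
    training = False

class DummyTrainer:
    def __init__(self):
        self.model = DummyModel()

    @with_training_mode
    def train(self):
        return self.model.training  # observed while inside the wrapped call

trainer = DummyTrainer()
result = trainer.train()
```

The `try/finally` is a small hardening over the sketch's happy path: the mode is restored even if the wrapped method raises, which is the usual reason to prefer this shape for mode-toggling decorators.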

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    # All Unsloth Zoo code licensed under AGPL3
    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    flat_index = index.reshape(-1)

    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)

    all_per_token_logps = []

    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()

        if logit_scale_multiply != 0.0:
            chunk_logits = chunk_logits * logit_scale_multiply
        if logit_scale_divide != 0.0:
            chunk_logits = chunk_logits / logit_scale_divide
        if logit_softcapping != 0.0:
            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)

        chunk_logits = chunk_logits.to(torch.float32)

        if temperature != 1.0:
            chunk_logits = chunk_logits / temperature

        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)

    all_per_token_logps = torch.concat(all_per_token_logps)

    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
    return all_per_token_logps

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_selective_log_softmax(logits, index):
    # Split into 4 chunks only
    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
    chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
    all_per_token_logps = []
    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
        chunk_logits = chunk_logits.to(torch.float32)
        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
        per_token_logps = selected_logits - logsumexp_values
        all_per_token_logps.append(per_token_logps)
    all_per_token_logps = torch.concat(all_per_token_logps)
    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
    return all_per_token_logps

| 161 |
+
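Both chunked routines above rely on the selective log-softmax identity: for each position, `log softmax(x)[i] = x[i] - logsumexp(x)`, so only the gathered logit and a row-wise logsumexp are needed. A minimal pure-Python sketch of the unchunked version (the helper name `selective_log_softmax_reference` is ours, for illustration only):

```python
import math

def selective_log_softmax_reference(logits, index):
    # For each row, compute log softmax(row)[i] = row[i] - logsumexp(row),
    # using the max-subtraction trick for numerical stability.
    out = []
    for row, i in zip(logits, index):
        m = max(row)
        logsumexp = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[i] - logsumexp)
    return out

# One row of three logits, selecting the log-probability of token index 2.
logps = selective_log_softmax_reference([[1.0, 2.0, 3.0]], [2])
```

Chunking the rows (as the compiled code does) changes nothing numerically; it only bounds the size of the materialized logits.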
def calculate_pad_tokens_in_prompt(
    input_ids: torch.Tensor,
    logits_to_keep: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a prompt tensor, returns the number of left-padding tokens in each sequence,
    e.g. [pad, pad, pad, cat] -> 3 tokens.
    """
    if logits_to_keep >= input_ids.shape[1]:
        raise ValueError("logits_to_keep must be smaller than the sequence length.")

    prompt_section = input_ids[:, :-logits_to_keep]

    padding_mask = (prompt_section == pad_token_id)

    pad_token_counts = padding_mask.sum(dim = 1)

    return pad_token_counts

def create_completion_attention_mask(
    completion_input_ids: torch.Tensor,
    left_pad_tokens_per_prompt: torch.Tensor,
    max_left_pad: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Given a sequence [p, p, p, c, c, c, pad, pad, pad], where p are extra prompt tokens
    obtained from slicing the tensor, c are completion tokens, and pad are padding tokens,
    build a completion mask that zeroes out the p and pad tokens.
    For this example: [0, 0, 0, 1, 1, 1, 0, 0, 0].
    """
    batch_size, completion_len = completion_input_ids.shape
    device = completion_input_ids.device

    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt

    indices = torch.arange(completion_len, device = device).unsqueeze(0)
    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)

    non_padding_mask = (completion_input_ids != pad_token_id)

    final_mask = shift_mask & non_padding_mask

    return final_mask

def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
    """
    Moves all padding tokens in each sequence of a batch to the right.
    """
    mask = (tensor != pad_id)
    # Must use stable=True since the binary mask is unordered
    sorted_indices = torch.argsort(mask, dim = 1, descending = True, stable = True)
    packed_tensor = torch.gather(tensor, 1, sorted_indices)
    return packed_tensor

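The stable-sort trick in `left_pack_padding` works because sorting on the boolean key "is a real token" (descending, stable) moves every pad to the right while preserving the relative order of real tokens. A pure-Python sketch of the same behavior (the helper name `left_pack_padding_reference` is ours, for illustration only):

```python
def left_pack_padding_reference(rows, pad_id):
    # Python's sorted() is stable: real tokens (key False) come before pads
    # (key True), each group keeping its original order.
    return [sorted(row, key = lambda t: t == pad_id) for row in rows]

# Pads (0) are pushed to the right; 5, 7, 8 keep their order.
packed = left_pack_padding_reference([[0, 0, 5, 0, 7, 8]], pad_id = 0)
```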
def align_logprobs_with_mask(
    logprob_tensor: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_value: float = 0.0
) -> torch.Tensor:
    """
    Aligns a log probability tensor with a given attention mask.
    """

    device = logprob_tensor.device
    batch_size, logprob_seq_len = logprob_tensor.shape
    mask_seq_len = attention_mask.shape[1]

    padded_logprobs = torch.full(
        attention_mask.shape,
        fill_value = pad_value,
        dtype = logprob_tensor.dtype,
        device = device
    )

    left_pad_counts = torch.argmax(attention_mask, dim = 1)

    cols = torch.arange(logprob_seq_len, device = device)
    dest_indices = left_pad_counts.unsqueeze(1) + cols

    # Create destination row indices
    # Shape: [batch_size, logprob_seq_len]
    row_indices = torch.arange(batch_size, device = device).unsqueeze(1).expand_as(dest_indices)

    # Filter out-of-bounds indices and perform the assignment.
    # Create a mask to identify only the indices that are within the bounds
    # of the target tensor's sequence length.
    valid_mask = dest_indices < mask_seq_len

    # Use this mask to select only the valid row indices, column indices,
    # and the corresponding values from the logprob tensor.
    # This flattens the selected elements into 1D tensors.
    valid_rows = row_indices[valid_mask]
    valid_cols = dest_indices[valid_mask]
    valid_vals = logprob_tensor[valid_mask]

    # Place the valid values into their correct positions in the padded tensor
    # using a single, efficient advanced indexing operation.
    padded_logprobs[valid_rows, valid_cols] = valid_vals

    return padded_logprobs

def autotune_batch_and_chunks(
    total_input_rows,
    seq_len,
    hidden_size,
    vocab_size,
    dtype_bytes = 16,
    multiplier = None,
):
    if multiplier is None:
        final_m = max(4, seq_len // 4096)
    else:
        final_m = multiplier

    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()
        limit_gb = (free_bytes / (1024**3)) * 0.80
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # For XPU: estimate free memory from total - reserved
        total_mem = torch.xpu.get_device_properties(0).total_memory
        reserved_mem = torch.xpu.memory_reserved()
        free_bytes = total_mem - reserved_mem
        limit_gb = (free_bytes / (1024**3)) * 0.80
    else:
        # Fallback: assume 8GB available
        limit_gb = 8.0

    bytes_to_gb = 1024**3

    b_vals = torch.arange(total_input_rows, 0, -1, device = 'cpu', dtype = torch.float32)

    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb

    base_logits = ((b_vals / total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
    logits_gb = base_logits / final_m

    total_mem_gb = hidden_gb + logits_gb

    valid_mask = total_mem_gb <= limit_gb
    valid_indices = torch.nonzero(valid_mask, as_tuple = False)

    if valid_indices.shape[0] == 0:
        # No batch size fits within the memory budget: the GPU will likely OOM
        return 4, final_m

    best_idx = valid_indices[0].item()
    final_b = int(b_vals[best_idx].item())

    return final_b, final_m

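The hidden-state term of the memory budget above is just `batch * seq_len * hidden_size * dtype_bytes` converted to GB. A back-of-envelope sketch (the helper name `estimate_hidden_gb` is ours, for illustration only; real dtype widths are 2 bytes for bf16/fp16, 4 for fp32):

```python
def estimate_hidden_gb(batch, seq_len, hidden_size, dtype_bytes = 2):
    # Mirrors the hidden_gb term in autotune_batch_and_chunks:
    # bytes = batch * seq_len * hidden_size * dtype_bytes, reported in GiB.
    return batch * seq_len * hidden_size * dtype_bytes / 1024**3

# 8 sequences of 4096 tokens with a 4096-dim hidden state in bf16.
gb = estimate_hidden_gb(batch = 8, seq_len = 4096, hidden_size = 4096, dtype_bytes = 2)
```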
def sanitize_logprob(logprob):
    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
    Filters NaN logprobs from vLLM outputs."""
    value = logprob.logprob
    if math.isnan(value):
        logging.getLogger(__name__).warning(
            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
        )
        return None
    return value

@dataclass
class UnslothRewardConfig(RewardConfig):
    """

    Configuration class for the [`RewardTrainer`].

    This class includes only the parameters that are specific to Reward training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
    class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model

        model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
            to include the load balancing/auxiliary loss as a part of the final loss, remember to set
            `output_router_logits=True` in this dictionary.
        chat_template_path (`str`, *optional*):
            If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
            or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
            ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
            embedding layer is resized accordingly.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.

        > Parameters that control the data preprocessing

        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.
        eos_token (`str`, *optional*):
            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
            `processing_class.eos_token`.
        pad_token (`str`, *optional*):
            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
            it falls back to `processing_class.eos_token`.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the tokenized sequence. Samples are filtered out if either the chosen or rejected
            sequence exceeds this value. If `None`, no filtering is applied.
        pad_to_multiple_of (`int`, *optional*):
            If set, the sequences will be padded to a multiple of this value.

        > Parameters that control the training

        center_rewards_coefficient (`float`, *optional*):
            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
        activation_offloading (`bool`, *optional*, defaults to `False`):
            Whether to offload the activations to the CPU.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    unsloth_logit_chunk_multiplier : Optional[int] = field(
        default = None,
        metadata = {'help': 'Multiplier for chunked logit computations.'},
    )
    unsloth_grpo_mini_batch : Optional[int] = field(
        default = None,
        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
    )
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )
    def __init__(
        self,
        output_dir = None,
        per_device_train_batch_size = 4,
        num_train_epochs = 3.0,
        max_steps = -1,
        learning_rate = 5e-05,
        lr_scheduler_type = 'linear',
        lr_scheduler_kwargs = None,
        warmup_steps = 0.1,
        optim = 'adamw_8bit',
        optim_args = None,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        optim_target_modules = None,
        gradient_accumulation_steps = 2,
        average_tokens_across_devices = True,
        max_grad_norm = 1.0,
        label_smoothing_factor = 0.0,
        bf16 = False,
        fp16 = False,
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = None,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        use_liger_kernel = False,
        liger_kernel_config = None,
        use_cache = False,
        neftune_noise_alpha = None,
        torch_empty_cache_steps = 250,
        auto_find_batch_size = False,
        logging_strategy = 'steps',
        logging_steps = 1,
        logging_first_step = False,
        log_on_each_node = True,
        logging_nan_inf_filter = False,
        include_num_input_tokens_seen = False,
        log_level = 'passive',
        log_level_replica = 'warning',
        disable_tqdm = None,
        report_to = 'none',
        run_name = None,
        project = 'huggingface',
        trackio_space_id = 'trackio',
        eval_strategy = 'no',
        eval_steps = None,
        eval_delay = 0,
        per_device_eval_batch_size = 4,
        prediction_loss_only = False,
        eval_on_start = False,
        eval_do_concat_batches = True,
        eval_use_gather_object = False,
        eval_accumulation_steps = 2,
        batch_eval_metrics = False,
        save_only_model = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_on_each_node = False,
        save_total_limit = None,
        enable_jit_checkpoint = False,
        push_to_hub = False,
        hub_token = None,
        hub_private_repo = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_always_push = False,
        hub_revision = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        restore_callback_states_from_checkpoint = False,
        full_determinism = False,
        seed = 3407,
        data_seed = 3407,
        use_cpu = False,
        accelerator_config = None,
        parallelism_config = None,
        dataloader_drop_last = False,
        dataloader_num_workers = 0,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        dataloader_prefetch_factor = None,
        remove_unused_columns = True,
        label_names = None,
        train_sampling_strategy = 'random',
        length_column_name = 'length',
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        ddp_backend = None,
        ddp_timeout = 1800,
        fsdp = None,
        fsdp_config = None,
        deepspeed = None,
        debug = '',
        skip_memory_metrics = True,
        do_train = False,
        do_eval = False,
        do_predict = False,
        resume_from_checkpoint = None,
        warmup_ratio = None,
        logging_dir = None,
        local_rank = -1,
        model_init_kwargs = None,
        chat_template_path = None,
        disable_dropout = True,
        dataset_num_proc = None,
        eos_token = None,
        pad_token = None,
        max_length = 1024,
        pad_to_multiple_of = None,
        center_rewards_coefficient = None,
        activation_offloading = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        unsloth_logit_chunk_multiplier = None,
        unsloth_grpo_mini_batch = None,
        max_seq_length = None,
        **kwargs,
    ):
        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if num_train_epochs is None:
            num_train_epochs = 3.0 # Default to 3 epochs if None; max_steps will override
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        import multiprocessing as _mp
        if _mp.get_start_method() != 'fork':
            dataset_num_proc = None
        elif dataset_num_proc is None:
            import psutil
            dataset_num_proc = min(max((psutil.cpu_count() or 1) + 4, 2), 64)
            memory_gb_left = psutil.virtual_memory().available / (1024**3)
            if memory_gb_left <= 2: dataset_num_proc = 1
            else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
        if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
            from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
            if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
                from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
                pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE

        super().__init__(
            output_dir = output_dir,
            per_device_train_batch_size = per_device_train_batch_size,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            learning_rate = learning_rate,
            lr_scheduler_type = lr_scheduler_type,
            lr_scheduler_kwargs = lr_scheduler_kwargs,
            warmup_steps = warmup_steps,
            optim = optim,
            optim_args = optim_args,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            optim_target_modules = optim_target_modules,
            gradient_accumulation_steps = gradient_accumulation_steps,
            average_tokens_across_devices = average_tokens_across_devices,
            max_grad_norm = max_grad_norm,
            label_smoothing_factor = label_smoothing_factor,
            bf16 = bf16,
            fp16 = fp16,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            use_cache = use_cache,
            neftune_noise_alpha = neftune_noise_alpha,
            torch_empty_cache_steps = torch_empty_cache_steps,
            auto_find_batch_size = auto_find_batch_size,
            logging_strategy = logging_strategy,
            logging_steps = logging_steps,
            logging_first_step = logging_first_step,
            log_on_each_node = log_on_each_node,
            logging_nan_inf_filter = logging_nan_inf_filter,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            log_level = log_level,
            log_level_replica = log_level_replica,
            disable_tqdm = disable_tqdm,
            report_to = report_to,
            run_name = run_name,
            project = project,
            trackio_space_id = trackio_space_id,
            eval_strategy = eval_strategy,
            eval_steps = eval_steps,
            eval_delay = eval_delay,
            per_device_eval_batch_size = per_device_eval_batch_size,
            prediction_loss_only = prediction_loss_only,
            eval_on_start = eval_on_start,
            eval_do_concat_batches = eval_do_concat_batches,
            eval_use_gather_object = eval_use_gather_object,
            eval_accumulation_steps = eval_accumulation_steps,
            batch_eval_metrics = batch_eval_metrics,
            save_only_model = save_only_model,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_on_each_node = save_on_each_node,
            save_total_limit = save_total_limit,
            enable_jit_checkpoint = enable_jit_checkpoint,
            push_to_hub = push_to_hub,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            full_determinism = full_determinism,
            seed = seed,
            data_seed = data_seed,
            use_cpu = use_cpu,
            accelerator_config = accelerator_config,
            parallelism_config = parallelism_config,
            dataloader_drop_last = dataloader_drop_last,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            train_sampling_strategy = train_sampling_strategy,
            length_column_name = length_column_name,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            ddp_backend = ddp_backend,
            ddp_timeout = ddp_timeout,
            fsdp = fsdp,
            fsdp_config = fsdp_config,
            deepspeed = deepspeed,
            debug = debug,
            skip_memory_metrics = skip_memory_metrics,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            resume_from_checkpoint = resume_from_checkpoint,
            warmup_ratio = warmup_ratio,
            logging_dir = logging_dir,
            local_rank = local_rank,
            model_init_kwargs = model_init_kwargs,
            chat_template_path = chat_template_path,
            disable_dropout = disable_dropout,
            dataset_num_proc = dataset_num_proc,
            eos_token = eos_token,
            pad_token = pad_token,
            max_length = max_length,
            pad_to_multiple_of = pad_to_multiple_of,
            center_rewards_coefficient = center_rewards_coefficient,
            activation_offloading = activation_offloading,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        if unsloth_grpo_mini_batch is not None:
            if self.generation_batch_size >= unsloth_grpo_mini_batch:
                self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
            else:
                raise ValueError(
                    "Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
                    "which is self.per_device_train_batch_size * gradient_accumulation_steps."
                )
        self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
        self.max_seq_length = max_seq_length
pass

class _UnslothRewardTrainer(BaseTrainer):
    """"""

    _tag_names = ["trl", "reward-trainer"]
    _name = "Reward"
    _template_file = "rm_model_card.md"

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = RewardConfig(f"{model_name}-Reward")

        # Model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            dtype = model_init_kwargs.get("dtype")
            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
                pass  # dtype is already a torch.dtype or "auto" or None
            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
                model_init_kwargs["dtype"] = getattr(torch, dtype)
            else:
                raise ValueError(
                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
                )
            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                logger.warning(
                    "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
                    "The `model_init_kwargs` will be ignored."
                )

        # Processing class
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model_id)

        # Handle pad token for processors or tokenizers
        if args.eos_token is not None:
            eos_token = args.eos_token
            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
            if eos_token_id is None:
                raise ValueError(
                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
                    "in the vocabulary before using it as an EOS token."
                )
            processing_class.eos_token_id = eos_token_id

        if args.chat_template_path is not None:
            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
                    processing_class.chat_template = chat_template_file.read()
                added_tokens = []
            else:
                model, processing_class, added_tokens = clone_chat_template(
                    model, processing_class, args.chat_template_path
                )
        else:
            added_tokens = []

        # PEFT configuration and model wrapping
        if False:
            if added_tokens:
                # Ensure that the added tokens are trainable
                if peft_config.trainable_token_indices is None:
                    peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
                elif "embed_tokens" not in peft_config.trainable_token_indices:
                    peft_config.trainable_token_indices["embed_tokens"] = added_tokens
                else:
                    peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)

                # Ensure that the lm_head is trainable
                if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
                    logger.warning(
                        "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
                        "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
                        "tokens, leading to degraded generation quality. To fix this, add "
                        "`modules_to_save=['lm_head']` to your PEFT configuration."
                    )

                    if peft_config.modules_to_save is None:
                        peft_config.modules_to_save = ["lm_head"]
                    else:
                        peft_config.modules_to_save.append("lm_head")

        if False:
            pass

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        # Pad token (needed for SequenceClassification models).
        # If not provided, use the one from the processing class, or the eos token if the processing class does not
        # have a pad token.
        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
        if pad_token_id is None:
            raise ValueError(
                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                "in the vocabulary before using it as a padding token."
            )
        model.config.pad_token_id = pad_token_id
        processing_class.pad_token_id = pad_token_id

        # Data collator
        if data_collator is None:
            data_collator = DataCollatorForPreference(
                pad_token_id=pad_token_id,
                pad_to_multiple_of=args.pad_to_multiple_of,
            )

        # Dataset
        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
        if eval_dataset is not None:
            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    key: self._prepare_dataset(dataset, processing_class, args, key)
                    for key, dataset in eval_dataset.items()
                }
            else:
                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0

        # Initialize the Trainer. The parent class will handle:
        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
        # - FSDP setup
        # - Distributed training setup
        # - Optimizer and scheduler creation

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
        self.can_return_loss = True
        self.label_names = []

        # Initialize activation offloading context
        if self.args.activation_offloading:
            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
        else:
            self.maybe_activation_offload_context = contextlib.nullcontext()

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)

    def _prepare_dataset(
        self,
        dataset: Union[Dataset, IterableDataset],
        processing_class: PreTrainedTokenizerBase,
        args: RewardConfig,
        dataset_name: str,
    ) -> Union[Dataset, IterableDataset]:
        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
        # sampled data.
+
if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform`
|
| 876 |
+
dataset = dataset.with_transform(remove_none_values)
|
| 877 |
+
|
| 878 |
+
# If the dataset is already preprocessed (tokenized), skip the processing steps.
|
| 879 |
+
column_names = list(next(iter(dataset)).keys())
|
| 880 |
+
is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
|
| 881 |
+
|
| 882 |
+
# Build the kwargs for the `map` function
|
| 883 |
+
map_kwargs = {}
|
| 884 |
+
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
|
| 885 |
+
map_kwargs["num_proc"] = args.dataset_num_proc
|
| 886 |
+
|
| 887 |
+
with PartialState().main_process_first():
|
| 888 |
+
if not is_processed:
|
| 889 |
+
# Add EOS token to the end of the sequences if needed
|
| 890 |
+
first_example = next(iter(dataset))
|
| 891 |
+
if not is_conversational(first_example):
|
| 892 |
+
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
| 893 |
+
map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
|
| 894 |
+
|
| 895 |
+
def add_eos(example, eos_token):
|
| 896 |
+
if not example["chosen"].endswith(eos_token):
|
| 897 |
+
example["chosen"] = example["chosen"] + eos_token
|
| 898 |
+
if "rejected" in example and not example["rejected"].endswith(eos_token):
|
| 899 |
+
example["rejected"] = example["rejected"] + eos_token
|
| 900 |
+
return example
|
| 901 |
+
|
| 902 |
+
dataset = dataset.map(
|
| 903 |
+
add_eos,
|
| 904 |
+
fn_kwargs={"eos_token": processing_class.eos_token},
|
| 905 |
+
**map_kwargs,
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
# Tokenize the dataset
|
| 909 |
+
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
| 910 |
+
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
| 911 |
+
|
| 912 |
+
def tokenize_fn(example, processing_class):
|
| 913 |
+
if "prompt" in example: # explicit prompt case
|
| 914 |
+
example["chosen"] = example["prompt"] + example["chosen"]
|
| 915 |
+
example["rejected"] = example["prompt"] + example["rejected"]
|
| 916 |
+
|
| 917 |
+
if is_conversational(example):
|
| 918 |
+
chosen_input_ids = processing_class.apply_chat_template(
|
| 919 |
+
example["chosen"],
|
| 920 |
+
tools=example.get("tools"),
|
| 921 |
+
**example.get("chat_template_kwargs", {}),
|
| 922 |
+
)
|
| 923 |
+
rejected_input_ids = processing_class.apply_chat_template(
|
| 924 |
+
example["rejected"],
|
| 925 |
+
tools=example.get("tools"),
|
| 926 |
+
**example.get("chat_template_kwargs", {}),
|
| 927 |
+
)
|
| 928 |
+
output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
|
| 929 |
+
else:
|
| 930 |
+
output = {
|
| 931 |
+
"chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
|
| 932 |
+
"rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
|
| 933 |
+
}
|
| 934 |
+
return output
|
| 935 |
+
|
| 936 |
+
dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
|
| 937 |
+
|
| 938 |
+
# Filter samples that are longer than `max_length`
|
| 939 |
+
if args.max_length is not None:
|
| 940 |
+
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
| 941 |
+
map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
|
| 942 |
+
dataset = dataset.filter(
|
| 943 |
+
lambda example: len(example["chosen_input_ids"]) <= args.max_length
|
| 944 |
+
and len(example["rejected_input_ids"]) <= args.max_length,
|
| 945 |
+
**map_kwargs,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
return dataset
|
| 949 |
+
|
| 950 |
+
def _set_signature_columns_if_needed(self):
|
| 951 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 952 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
|
| 953 |
+
# and "attention_mask").
|
| 954 |
+
if self._signature_columns is None:
|
| 955 |
+
self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
|
| 956 |
+
|
| 957 |
+
def compute_loss(
|
| 958 |
+
self,
|
| 959 |
+
model: nn.Module,
|
| 960 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 961 |
+
return_outputs: bool = False,
|
| 962 |
+
num_items_in_batch: Optional[torch.Tensor] = None,
|
| 963 |
+
):
|
| 964 |
+
"""
|
| 965 |
+
Compute training loss and additionally compute token accuracies
|
| 966 |
+
"""
|
| 967 |
+
mode = "train" if self.model.training else "eval"
|
| 968 |
+
|
| 969 |
+
# If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
|
| 970 |
+
inputs["use_cache"] = False
|
| 971 |
+
outputs = model(**inputs)
|
| 972 |
+
|
| 973 |
+
# Split the rewards into chosen and rejected
|
| 974 |
+
rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
|
| 975 |
+
|
| 976 |
+
# Calculate loss, optionally modulate with margin
|
| 977 |
+
if "margin" in inputs:
|
| 978 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
|
| 979 |
+
else:
|
| 980 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
| 981 |
+
|
| 982 |
+
if self.args.center_rewards_coefficient is not None:
|
| 983 |
+
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
|
| 984 |
+
|
| 985 |
+
if mode == "train":
|
| 986 |
+
num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
|
| 987 |
+
self._total_train_tokens += num_tokens_in_batch
|
| 988 |
+
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
|
| 989 |
+
|
| 990 |
+
# Compute min, mean, max, accuracy and margin
|
| 991 |
+
with torch.no_grad():
|
| 992 |
+
all_rewards = self.accelerator.gather(outputs.logits)
|
| 993 |
+
self._metrics[mode]["min_reward"].append(all_rewards.min().item())
|
| 994 |
+
self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
|
| 995 |
+
self._metrics[mode]["max_reward"].append(all_rewards.max().item())
|
| 996 |
+
|
| 997 |
+
mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
|
| 998 |
+
mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
|
| 999 |
+
self._metrics[mode]["accuracy"].append(mean_accuracy)
|
| 1000 |
+
|
| 1001 |
+
mean_margin = (rewards_chosen - rewards_rejected).mean()
|
| 1002 |
+
mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
|
| 1003 |
+
self._metrics[mode]["margin"].append(mean_margin.item())
|
| 1004 |
+
|
| 1005 |
+
return (loss, outputs) if return_outputs else loss
|
| 1006 |
+
|
| 1007 |
+
# Override training step to add activation offloading context.
|
| 1008 |
+
def training_step(self, *args, **kwargs):
|
| 1009 |
+
with self.maybe_activation_offload_context:
|
| 1010 |
+
return super().training_step(*args, **kwargs)
|
| 1011 |
+
|
| 1012 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1013 |
+
mode = "train" if self.model.training else "eval"
|
| 1014 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
|
| 1015 |
+
|
| 1016 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
| 1017 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
| 1018 |
+
if mode == "eval":
|
| 1019 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
| 1020 |
+
|
| 1021 |
+
logs.update(metrics)
|
| 1022 |
+
super().log(logs, start_time)
|
| 1023 |
+
self._metrics[mode].clear()
|
| 1024 |
+
|
| 1025 |
+
# Ensure the model card is saved along with the checkpoint
|
| 1026 |
+
def _save_checkpoint(self, model, trial):
|
| 1027 |
+
if self.args.hub_model_id is None:
|
| 1028 |
+
model_name = Path(self.args.output_dir).name
|
| 1029 |
+
else:
|
| 1030 |
+
model_name = self.args.hub_model_id.split("/")[-1]
|
| 1031 |
+
self.create_model_card(model_name=model_name)
|
| 1032 |
+
super()._save_checkpoint(model, trial)
|
| 1033 |
+
class UnslothRewardTrainer(_UnslothRewardTrainer):
|
| 1034 |
+
"""
|
| 1035 |
+
|
| 1036 |
+
Trainer for Outcome-supervised Reward Models (ORM).
|
| 1037 |
+
|
| 1038 |
+
This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
|
| 1039 |
+
|
| 1040 |
+
Example:
|
| 1041 |
+
|
| 1042 |
+
```python
|
| 1043 |
+
from trl import RewardTrainer
|
| 1044 |
+
from datasets import load_dataset
|
| 1045 |
+
|
| 1046 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
| 1047 |
+
|
| 1048 |
+
trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
|
| 1049 |
+
trainer.train()
|
| 1050 |
+
```
|
| 1051 |
+
|
| 1052 |
+
Args:
|
| 1053 |
+
model (`Union[str, PreTrainedModel]`):
|
| 1054 |
+
Model to be trained. Can be either:
|
| 1055 |
+
|
| 1056 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
| 1057 |
+
path to a *directory* containing model weights saved using
|
| 1058 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
| 1059 |
+
using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
|
| 1060 |
+
`args.model_init_kwargs`.
|
| 1061 |
+
- A sequence classification [`~transformers.PreTrainedModel`] object.
|
| 1062 |
+
args ([`RewardConfig`], *optional*):
|
| 1063 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
| 1064 |
+
data_collator ([`~transformers.DataCollator`], *optional*):
|
| 1065 |
+
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
|
| 1066 |
+
Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
|
| 1067 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 1068 |
+
Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
|
| 1069 |
+
explicit prompt). The format of the samples can be either:
|
| 1070 |
+
|
| 1071 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 1072 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 1073 |
+
and content).
|
| 1074 |
+
|
| 1075 |
+
The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
|
| 1076 |
+
`rejected_input_ids` fields.
|
| 1077 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 1078 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 1079 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
|
| 1080 |
+
Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
|
| 1081 |
+
[`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
|
| 1082 |
+
set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
|
| 1083 |
+
default.
|
| 1084 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1085 |
+
The function that will be used to compute metrics at evaluation. Must take a
|
| 1086 |
+
[`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
|
| 1087 |
+
[`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
|
| 1088 |
+
boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
|
| 1089 |
+
function needs to calculate and return the global summary statistics rather than accumulating the
|
| 1090 |
+
batch-level statistics.
|
| 1091 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*):
|
| 1092 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
|
| 1093 |
+
in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 1094 |
+
|
| 1095 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 1096 |
+
method.
|
| 1097 |
+
optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
|
| 1098 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
|
| 1099 |
+
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
|
| 1100 |
+
optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
|
| 1101 |
+
A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
|
| 1102 |
+
`args`. Incompatible with the `optimizers` argument.
|
| 1103 |
+
|
| 1104 |
+
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
|
| 1105 |
+
initializing the Trainer.
|
| 1106 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
|
| 1107 |
+
A function that preprocess the logits right before caching them at each evaluation step. Must take two
|
| 1108 |
+
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
|
| 1109 |
+
by this function will be reflected in the predictions received by `compute_metrics`.
|
| 1110 |
+
|
| 1111 |
+
Note that the labels (second parameter) will be `None` if the dataset does not have them.
|
| 1112 |
+
peft_config ([`~peft.PeftConfig`], *optional*):
|
| 1113 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
|
| 1114 |
+
model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
|
| 1115 |
+
to ensure that the reward head is properly trained.
|
| 1116 |
+
|
| 1117 |
+
"""
|
| 1118 |
+
def __init__(
|
| 1119 |
+
self,
|
| 1120 |
+
model,
|
| 1121 |
+
args = None,
|
| 1122 |
+
data_collator = None,
|
| 1123 |
+
train_dataset = None,
|
| 1124 |
+
eval_dataset = None,
|
| 1125 |
+
processing_class = None,
|
| 1126 |
+
compute_metrics = None,
|
| 1127 |
+
callbacks = None,
|
| 1128 |
+
optimizer_cls_and_kwargs = None,
|
| 1129 |
+
preprocess_logits_for_metrics = None,
|
| 1130 |
+
peft_config = None,
|
| 1131 |
+
**kwargs
|
| 1132 |
+
):
|
| 1133 |
+
if args is None: args = UnslothRewardConfig()
|
| 1134 |
+
use_bf16 = getattr(args, 'bf16', False)
|
| 1135 |
+
if type(use_bf16) is not bool: use_bf16 = False
|
| 1136 |
+
use_fp16 = getattr(args, 'fp16', False)
|
| 1137 |
+
if type(use_fp16) is not bool: use_fp16 = False
|
| 1138 |
+
force_float32 = False
|
| 1139 |
+
full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
|
| 1140 |
+
if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
|
| 1141 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1142 |
+
force_float32 = True
|
| 1143 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1144 |
+
dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
|
| 1145 |
+
if dtype is None: dtype = model.get_input_embeddings().weight.dtype
|
| 1146 |
+
from unsloth_zoo.utils import _get_dtype
|
| 1147 |
+
dtype = _get_dtype(dtype)
|
| 1148 |
+
float16 = dtype == torch.float16
|
| 1149 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1150 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1151 |
+
if force_float32:
|
| 1152 |
+
# Forced float32 training
|
| 1153 |
+
args.fp16 = False
|
| 1154 |
+
args.bf16 = False
|
| 1155 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1156 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
|
| 1157 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1158 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1159 |
+
# Mixed precision training
|
| 1160 |
+
args.fp16 = float16
|
| 1161 |
+
args.bf16 = not float16
|
| 1162 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1163 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
|
| 1164 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1165 |
+
elif mixed_precision_dtype == 'bfloat16':
|
| 1166 |
+
# Both False since bfloat16 full finetuning doesn't do any autocasting.
|
| 1167 |
+
args.fp16 = False
|
| 1168 |
+
args.bf16 = False
|
| 1169 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1170 |
+
if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
|
| 1171 |
+
# args.mixed_precision is a new argument which needs to be set now
|
| 1172 |
+
|
| 1173 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1174 |
+
args.eval_strategy = 'steps'
|
| 1175 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1176 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1177 |
+
if ga_steps is not None and ga_steps > 1:
|
| 1178 |
+
from transformers import __version__ as transformers_version
|
| 1179 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
| 1180 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1181 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1182 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1183 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1184 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1185 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1186 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1187 |
+
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1188 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1189 |
+
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1190 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1191 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1192 |
+
if force_float32:
|
| 1193 |
+
args.bf16_full_eval = False
|
| 1194 |
+
args.fp16_full_eval = False
|
| 1195 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1196 |
+
args.bf16_full_eval = True
|
| 1197 |
+
args.fp16_full_eval = False
|
| 1198 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1199 |
+
args.bf16_full_eval = args.bf16
|
| 1200 |
+
args.fp16_full_eval = args.fp16
|
| 1201 |
+
_output_logits = False
|
| 1202 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1203 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1204 |
+
if _output_logits:
|
| 1205 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1206 |
+
if model is not None:
|
| 1207 |
+
_warnings_issued = getattr(model, 'warnings_issued', None)
|
| 1208 |
+
if _warnings_issued is None:
|
| 1209 |
+
model.warnings_issued = {}
|
| 1210 |
+
elif not isinstance(_warnings_issued, dict):
|
| 1211 |
+
try:
|
| 1212 |
+
model.warnings_issued = dict(_warnings_issued)
|
| 1213 |
+
except Exception:
|
| 1214 |
+
model.warnings_issued = {}
|
| 1215 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1216 |
+
pass
|
| 1217 |
+
else:
|
| 1218 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1219 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1220 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1221 |
+
max_seq_length = model.max_seq_length
|
| 1222 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1223 |
+
elif args_max_seq_length is not None and model_max_seq_length is not None:
|
| 1224 |
+
if args_max_seq_length > model_max_seq_length:
|
| 1225 |
+
print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
|
| 1226 |
+
'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
|
| 1227 |
+
args.max_seq_length = model_max_seq_length
|
| 1228 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1229 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1230 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1231 |
+
if 'processing_class' in locals():
|
| 1232 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1233 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1234 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1235 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1236 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1237 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1238 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1239 |
+
__tokenizer,
|
| 1240 |
+
mlm = False,
|
| 1241 |
+
mlm_probability = 0.0,
|
| 1242 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1243 |
+
)
|
| 1244 |
+
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1245 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1246 |
+
__tokenizer,
|
| 1247 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1248 |
+
)
|
| 1249 |
+
else:
|
| 1250 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1251 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1252 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1253 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1254 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1255 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1256 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 1257 |
+
__tokenizer.tokenizer,
|
| 1258 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1259 |
+
)
|
| 1260 |
+
else:
|
| 1261 |
+
data_collator = TransformersDataCollatorForLanguageModeling(
|
| 1262 |
+
__tokenizer.tokenizer,
|
| 1263 |
+
mlm = False,
|
| 1264 |
+
mlm_probability = 0.0,
|
| 1265 |
+
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
| 1266 |
+
)
|
| 1267 |
+
other_metrics = []
|
| 1268 |
+
|
| 1269 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1270 |
+
PatchRLStatistics('reward_trainer', other_metrics)
|
| 1271 |
+
|
| 1272 |
+
# [TODO] Fix up DataParallel multiplying batch sizes
|
| 1273 |
+
# [TODO] DDP works, but DP seems to not work? [TODO]
|
| 1274 |
+
if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
|
| 1275 |
+
if getattr(args, "_n_gpu", 1) != 1:
|
| 1276 |
+
args._n_gpu = 1
|
| 1277 |
+
if "model" in locals() and hasattr(model, "for_training"):
|
| 1278 |
+
model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
|
| 1279 |
+
super().__init__(
|
| 1280 |
+
model = model,
|
| 1281 |
+
args = args,
|
| 1282 |
+
data_collator = data_collator,
|
| 1283 |
+
train_dataset = train_dataset,
|
| 1284 |
+
eval_dataset = eval_dataset,
|
| 1285 |
+
processing_class = processing_class,
|
| 1286 |
+
compute_metrics = compute_metrics,
|
| 1287 |
+
callbacks = callbacks,
|
| 1288 |
+
optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
|
| 1289 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1290 |
+
peft_config = peft_config,**kwargs)
|
| 1291 |
+
if "model" in locals() and hasattr(model, "for_inference"):
|
| 1292 |
+
model.for_inference()
|
| 1293 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1294 |
+
self.neftune_hook_handle.remove()
|
| 1295 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1296 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1297 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1298 |
+
pass
|
| 1299 |
+
if hasattr(self, 'accelerator'):
|
| 1300 |
+
scaler = self.accelerator.scaler
|
| 1301 |
+
current_model = model
|
| 1302 |
+
while hasattr(current_model, 'model'):
|
| 1303 |
+
current_model.accelerator_scaler = scaler
|
| 1304 |
+
current_model = current_model.model
|
| 1305 |
+
current_model.accelerator_scaler = scaler
|
| 1306 |
+
pass
|
| 1307 |
+
if hasattr(self, 'train'):
|
| 1308 |
+
self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
|
| 1309 |
+
pass
|
| 1310 |
+
if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
|
| 1311 |
+
_vllm_tok = self.llm.get_tokenizer()
|
| 1312 |
+
_pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
|
| 1313 |
+
if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
|
| 1314 |
+
_vllm_tok.chat_template = _pc.chat_template
|
| 1315 |
+
pass
|
| 1316 |
+
|
| 1317 |
+
pass
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
if hasattr(logger, "addFilter"):
|
| 1321 |
+
import logging
|
| 1322 |
+
class HideLoggingMessage(logging.Filter):
|
| 1323 |
+
def __init__(self, text): self.text = text
|
| 1324 |
+
def filter(self, x): return not (self.text in x.getMessage())
|
| 1325 |
+
pass
|
| 1326 |
+
logger.addFilter(HideLoggingMessage("`use_cache=True`"))
|
| 1327 |
+
|
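The objective implemented in `compute_loss` above is the Bradley-Terry pairwise reward loss, `-log sigmoid(r_chosen - r_rejected - margin)`. As a minimal sketch of that formula, independent of torch and of the generated trainer (the function name and list-based signature here are illustrative, not part of the file above):

```python
import math

def pairwise_reward_loss(rewards_chosen, rewards_rejected, margins=None):
    """Bradley-Terry pairwise loss: -mean(log sigmoid(chosen - rejected - margin))."""
    losses = []
    for i, (c, r) in enumerate(zip(rewards_chosen, rewards_rejected)):
        diff = c - r - (margins[i] if margins is not None else 0.0)
        # -logsigmoid(x) == log(1 + exp(-x)), computed in a numerically stable way
        if diff >= 0:
            losses.append(math.log1p(math.exp(-diff)))
        else:
            losses.append(-diff + math.log1p(math.exp(diff)))
    return sum(losses) / len(losses)
```

The loss shrinks as the chosen reward exceeds the rejected one, and an optional per-pair margin raises the bar the chosen completion must clear, mirroring the `"margin"` branch in the trainer's `compute_loss`.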
unsloth_compiled_cache/UnslothSFTTrainer.py
ADDED
@@ -0,0 +1,1588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+"""
+2026.3.2
+2026.3.4
+5.3.0
+0.24.0
+__UNSLOTH_VERSIONING__
+"""
+
+# Unsloth auto generated code
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from unsloth_zoo.temporary_patches.common import torch_compile
+from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+from trl.trainer.sft_trainer import (Any, AutoProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, dataclass, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pack_dataset, pad, selective_log_softmax, torch, Any, AutoProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, os, pad, torch, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, PreTrainedModel, logger, os, torch, os)
+
+
+import os
+import math
+import logging
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+import inspect
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
+from transformers.training_args import ParallelMode
+from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
+
+# Wrap trainer with padding to right and enable training mode
+# Also patches W&B since multiple runs must use wandb.finish()
+import functools
+from types import MethodType
+try:
+    from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
+except:
+    def reset_unsloth_gradient_checkpointing_buffers(): pass
+def prepare_for_training_mode(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        # Enable training mode
+        _was_training = None
+        # Get gradient checkpointing setting from training arguments
+        use_gc = getattr(self.args, 'gradient_checkpointing', True)
+        if hasattr(self, 'model') and hasattr(self.model, "training"):
+            _was_training = self.model.training
+        if hasattr(self, 'model') and hasattr(self.model, "for_training"):
+            self.model.for_training(use_gradient_checkpointing=use_gc)
+        output = f(self, *args, **kwargs)
+        # Restore previous mode when possible
+        if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
+            if _was_training is False:
+                self.model.for_inference()
+            elif _was_training is True and hasattr(self.model, "for_training"):
+                self.model.for_training(use_gradient_checkpointing=use_gc)
+        # Reset gradient checkpointing buffers to free memory while staying ready for next run
+        try:
+            reset_unsloth_gradient_checkpointing_buffers()
+        except:
+            pass
+        # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
+        try:
+            import wandb
+            wandb.finish()
+        except:
+            pass
+        return output
+    return wrapper
+pass
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_hidden_states_selective_log_softmax(
+    hidden_states: torch.Tensor,
+    lm_head: torch.Tensor,
+    index: torch.Tensor,
+    chunks: int = 4,
+    logit_scale_multiply: float = 0.0,
+    logit_scale_divide: float = 0.0,
+    logit_softcapping: float = 0.0,
+    temperature: float = 1.0,
+) -> torch.Tensor:
+    # All Unsloth Zoo code licensed under AGPL3
+    flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+    flat_index = index.reshape(-1)
+
+    chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
+    chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+
+    all_per_token_logps = []
+
+    for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
+        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+
+        if logit_scale_multiply != 0.0:
+            chunk_logits = chunk_logits * logit_scale_multiply
+        if logit_scale_divide != 0.0:
+            chunk_logits = chunk_logits / logit_scale_divide
+        if logit_softcapping != 0.0:
+            chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
+
+        chunk_logits = chunk_logits.to(torch.float32)
+
+        if temperature != 1.0:
+            chunk_logits = chunk_logits / temperature
+
+        selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+
+    all_per_token_logps = torch.concat(all_per_token_logps)
+
+    all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
+    return all_per_token_logps
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def chunked_selective_log_softmax(logits, index):
+    # Split into 4 chunks only
+    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
+    chunked_index  = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
+    all_per_token_logps = []
+    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
+    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
+        chunk_logits = chunk_logits.to(torch.float32)
+        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
+        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
+        per_token_logps = selected_logits - logsumexp_values
+        all_per_token_logps.append(per_token_logps)
+    pass
+    all_per_token_logps = torch.concat(all_per_token_logps)
+    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
+    return all_per_token_logps
+
+def calculate_pad_tokens_in_prompt(
+    input_ids: torch.Tensor,
+    logits_to_keep: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
+    """
+    if logits_to_keep >= input_ids.shape[1]:
+        raise ValueError("logits_to_keep must be smaller than the sequence length.")
+
+    prompt_section = input_ids[:, :-logits_to_keep]
+
+    padding_mask = (prompt_section == pad_token_id)
+
+    pad_token_counts = padding_mask.sum(dim=1)
+
+    return pad_token_counts
+
+def create_completion_attention_mask(
+    completion_input_ids: torch.Tensor,
+    left_pad_tokens_per_prompt: torch.Tensor,
+    max_left_pad: int,
+    pad_token_id: int
+) -> torch.Tensor:
+    """
+    Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
+
+    Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
+    and pad are pad tokens, this function would make a completion mask that would 0 out the pad
+    and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
+    """
+    batch_size, completion_len = completion_input_ids.shape
+    device = completion_input_ids.device
+
+    num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
+
+    indices = torch.arange(completion_len, device=device).unsqueeze(0)
+    shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
+
+    non_padding_mask = (completion_input_ids != pad_token_id)
+
+    final_mask = shift_mask & non_padding_mask
+
+    return final_mask
+
+def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
+    """
+    Moves all padding tokens in each sequence of a batch to the right.
+    """
+    mask = (tensor != pad_id)
+    # Must do stable=True since binary mark is unordered
+    sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
+    packed_tensor = torch.gather(tensor, 1, sorted_indices)
+    return packed_tensor
+
+def align_logprobs_with_mask(
+    logprob_tensor: torch.Tensor,
+    attention_mask: torch.Tensor,
+    pad_value: float = 0.0
+) -> torch.Tensor:
+    """
+    Aligns a log probability tensor with a given attention mask.
+    """
+
+    device = logprob_tensor.device
+    batch_size, logprob_seq_len = logprob_tensor.shape
+    mask_seq_len = attention_mask.shape[1]
+
+    padded_logprobs = torch.full(
+        attention_mask.shape,
+        fill_value=pad_value,
+        dtype=logprob_tensor.dtype,
+        device=device
+    )
+
+    left_pad_counts = torch.argmax(attention_mask, dim=1)
+
+    cols = torch.arange(logprob_seq_len, device=device)
+    dest_indices = left_pad_counts.unsqueeze(1) + cols
+
+    # Create destination row indices
+    # Shape: [batch_size, logprob_seq_len]
+    row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
+
+    # --- 4. Filter out-of-bounds indices and perform assignment ---
+    # Create a mask to identify only the indices that are within the bounds
+    # of the target tensor's sequence length.
+    valid_mask = dest_indices < mask_seq_len
+
+    # Use this mask to select only the valid row indices, column indices,
+    # and the corresponding values from the logprob tensor.
+    # This flattens the selected elements into 1D tensors.
+    valid_rows = row_indices[valid_mask]
+    valid_cols = dest_indices[valid_mask]
+    valid_vals = logprob_tensor[valid_mask]
+
+    # Place the valid values into their correct positions in the padded tensor
+    # using a single, efficient advanced indexing operation.
+    padded_logprobs[valid_rows, valid_cols] = valid_vals
+
+    return padded_logprobs
+
+def autotune_batch_and_chunks(
+    total_input_rows,
+    seq_len,
+    hidden_size,
+    vocab_size,
+    dtype_bytes=16,
+    multiplier=None
+):
+    if multiplier is None:
+        final_m = max(4, seq_len // 4096)
+    else:
+        final_m = multiplier
+
+    if torch.cuda.is_available():
+        free_bytes, _ = torch.cuda.mem_get_info()
+        limit_gb = (free_bytes / (1024**3))*.80
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        # For XPU: estimate free memory from total - reserved
+        total_mem = torch.xpu.get_device_properties(0).total_memory
+        reserved_mem = torch.xpu.memory_reserved()
+        free_bytes = total_mem - reserved_mem
+        limit_gb = (free_bytes / (1024**3)) * 0.80
+    else:
+        # Fallback: assume 8GB available
+        limit_gb = 8.0
+
+    bytes_to_gb = 1024**3
+
+    b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
+
+    hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
+
+    base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
+    logits_gb = base_logits / final_m
+
+    total_mem_gb = hidden_gb + logits_gb
+
+    valid_mask = total_mem_gb <= limit_gb
+    valid_indices = torch.nonzero(valid_mask, as_tuple=False)
+
+    if valid_indices.shape[0] == 0:
+        # This means your GPU will OOM
+        return 4, final_m
+
+    best_idx = valid_indices[0].item()
+    final_b = int(b_vals[best_idx].item())
+
+    return final_b, final_m
+
+def sanitize_logprob(logprob):
+    """Local port of trl.scripts.vllm_serve.sanitize_logprob.
+    Filters NaN logprobs from vLLM outputs."""
+    value = logprob.logprob
+    if math.isnan(value):
+        logging.getLogger(__name__).warning(
+            f"Generated NaN logprob, token logprob '{logprob}' will be ignored"
+        )
+        return None
+    return value
+@dataclass
+class UnslothSFTConfig(SFTConfig):
+    """
+
+    Configuration class for the [`SFTTrainer`].
+
+    This class includes only the parameters that are specific to SFT training. For a full list of training arguments,
+    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+    differ from those in [`~transformers.TrainingArguments`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        > Parameters that control the model
+
+        model_init_kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
+            argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to
+            include the load balancing/auxilliary loss as a part of the final loss, remember to set
+            `output_router_logits=True` in this dictionary.
+        chat_template_path (`str`, *optional*):
+            If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
+            or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
+            ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
+            embedding layer is resized accordingly.
+
+        > Parameters that control the data preprocessing
+
+        dataset_text_field (`str`, *optional*, defaults to `"text"`):
+            Name of the column that contains text data in the dataset.
+        dataset_kwargs (`dict[str, Any]`, *optional*):
+            Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
+            `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True`
+            regardless of the provided value, since preprocessing is done on the fly.
+        dataset_num_proc (`int`, *optional*):
+            Number of processes to use for processing the dataset.
+        eos_token (`str`, *optional*):
+            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
+            `processing_class.eos_token`.
+        pad_token (`str`, *optional*):
+            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+            it falls back to `processing_class.eos_token`.
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
+            If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
+        packing (`bool`, *optional*, defaults to `False`):
+            Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
+            padding. Uses `max_length` to define sequence length.
+        packing_strategy (`str`, *optional*, defaults to `"bfd"`):
+            Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`.
+        padding_free (`bool`, *optional*, defaults to `False`):
+            Whether to perform forward passes without padding by flattening all sequences in the batch into a single
+            continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
+            supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When
+            packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this
+            parameter.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+        eval_packing (`bool`, *optional*):
+            Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
+
+        > Parameters that control the training
+
+        completion_only_loss (`bool`, *optional*):
+            Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
+            only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
+            `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
+            loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full
+            sequence for [language modeling](#language-modeling) datasets.
+        assistant_only_loss (`bool`, *optional*, defaults to `False`):
+            Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only
+            on the assistant responses, which is supported only for [conversational](#conversational) datasets. If
+            `False`, loss is computed on the entire sequence.
+        loss_type (`str`, *optional*, defaults to `"nll"`):
+            Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic
+            Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)).
+        activation_offloading (`bool`, *optional*, defaults to `False`):
+            Whether to offload the activations to the CPU.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    unsloth_logit_chunk_multiplier : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Multiplier for chunked logit computations.'},
+    )
+    unsloth_grpo_mini_batch : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
+    )
+    max_seq_length : Optional[int] = field(
+        default = None,
+        metadata = {'help': 'Maximum sequence length to truncate to.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        per_device_train_batch_size = 4,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        learning_rate = 5e-05,
+        lr_scheduler_type = 'linear',
+        lr_scheduler_kwargs = None,
+        warmup_steps = 0.1,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        optim_target_modules = None,
+        gradient_accumulation_steps = 2,
+        average_tokens_across_devices = True,
+        max_grad_norm = 1.0,
+        label_smoothing_factor = 0.0,
+        bf16 = False,
+        fp16 = False,
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = None,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        use_liger_kernel = False,
+        liger_kernel_config = None,
+        use_cache = False,
+        neftune_noise_alpha = None,
+        torch_empty_cache_steps = 250,
+        auto_find_batch_size = False,
+        logging_strategy = 'steps',
+        logging_steps = 1,
+        logging_first_step = False,
+        log_on_each_node = True,
+        logging_nan_inf_filter = False,
+        include_num_input_tokens_seen = False,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        disable_tqdm = None,
+        report_to = 'none',
+        run_name = None,
+        project = 'huggingface',
+        trackio_space_id = 'trackio',
+        eval_strategy = 'no',
+        eval_steps = None,
+        eval_delay = 0,
+        per_device_eval_batch_size = 4,
+        prediction_loss_only = False,
+        eval_on_start = False,
+        eval_do_concat_batches = True,
+        eval_use_gather_object = False,
+        eval_accumulation_steps = 2,
+        batch_eval_metrics = False,
+        save_only_model = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_on_each_node = False,
+        save_total_limit = None,
+        enable_jit_checkpoint = False,
+        push_to_hub = False,
+        hub_token = None,
+        hub_private_repo = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_always_push = False,
+        hub_revision = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        restore_callback_states_from_checkpoint = False,
+        full_determinism = False,
+        seed = 3407,
+        data_seed = 3407,
+        use_cpu = False,
+        accelerator_config = None,
+        parallelism_config = None,
+        dataloader_drop_last = False,
+        dataloader_num_workers = 0,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        dataloader_prefetch_factor = None,
+        remove_unused_columns = True,
+        label_names = None,
+        train_sampling_strategy = 'random',
+        length_column_name = 'length',
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        ddp_backend = None,
+        ddp_timeout = 1800,
+        fsdp = None,
+        fsdp_config = None,
+        deepspeed = None,
+        debug = '',
+        skip_memory_metrics = True,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        resume_from_checkpoint = None,
+        warmup_ratio = None,
+        logging_dir = None,
+        local_rank = -1,
+        model_init_kwargs = None,
+        chat_template_path = None,
+        dataset_text_field = 'text',
+        dataset_kwargs = None,
+        dataset_num_proc = None,
+        eos_token = None,
+        pad_token = None,
+        max_length = 1024,
+        packing = False,
+        packing_strategy = 'bfd',
+        padding_free = None,
+        pad_to_multiple_of = None,
+        eval_packing = None,
+        completion_only_loss = None,
+        assistant_only_loss = False,
+        loss_type = 'nll',
+        activation_offloading = False,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        unsloth_logit_chunk_multiplier = None,
+        unsloth_grpo_mini_batch = None,
+        max_seq_length = None,
+        **kwargs,
+    ):
+        if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 561 |
+
if num_train_epochs is None:
|
| 562 |
+
num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
|
| 563 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 564 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 565 |
+
save_strategy = 'no'
|
| 566 |
+
import multiprocessing as _mp
|
| 567 |
+
if _mp.get_start_method() != 'fork':
|
| 568 |
+
dataset_num_proc = None
|
| 569 |
+
elif dataset_num_proc is None:
|
| 570 |
+
import psutil
|
| 571 |
+
dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
|
| 572 |
+
memory_gb_left = psutil.virtual_memory().available / (1024**3)
|
| 573 |
+
if memory_gb_left <= 2: dataset_num_proc = 1
|
| 574 |
+
else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
|
| 575 |
+
if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
|
| 576 |
+
from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
|
| 577 |
+
if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
|
| 578 |
+
from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
|
| 579 |
+
pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
super().__init__(
|
| 583 |
+
output_dir = output_dir,
|
| 584 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 585 |
+
num_train_epochs = num_train_epochs,
|
| 586 |
+
max_steps = max_steps,
|
| 587 |
+
learning_rate = learning_rate,
|
| 588 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 589 |
+
lr_scheduler_kwargs = lr_scheduler_kwargs,
|
| 590 |
+
warmup_steps = warmup_steps,
|
| 591 |
+
optim = optim,
|
| 592 |
+
optim_args = optim_args,
|
| 593 |
+
weight_decay = weight_decay,
|
| 594 |
+
adam_beta1 = adam_beta1,
|
| 595 |
+
adam_beta2 = adam_beta2,
|
| 596 |
+
adam_epsilon = adam_epsilon,
|
| 597 |
+
optim_target_modules = optim_target_modules,
|
| 598 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 599 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 600 |
+
max_grad_norm = max_grad_norm,
|
| 601 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 602 |
+
bf16 = bf16,
|
| 603 |
+
fp16 = fp16,
|
| 604 |
+
bf16_full_eval = bf16_full_eval,
|
| 605 |
+
fp16_full_eval = fp16_full_eval,
|
| 606 |
+
tf32 = tf32,
|
| 607 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 608 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 609 |
+
torch_compile = torch_compile,
|
| 610 |
+
torch_compile_backend = torch_compile_backend,
|
| 611 |
+
torch_compile_mode = torch_compile_mode,
|
| 612 |
+
use_liger_kernel = use_liger_kernel,
|
| 613 |
+
liger_kernel_config = liger_kernel_config,
|
| 614 |
+
use_cache = use_cache,
|
| 615 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 616 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 617 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 618 |
+
logging_strategy = logging_strategy,
|
| 619 |
+
logging_steps = logging_steps,
|
| 620 |
+
logging_first_step = logging_first_step,
|
| 621 |
+
log_on_each_node = log_on_each_node,
|
| 622 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 623 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 624 |
+
log_level = log_level,
|
| 625 |
+
log_level_replica = log_level_replica,
|
| 626 |
+
disable_tqdm = disable_tqdm,
|
| 627 |
+
report_to = report_to,
|
| 628 |
+
run_name = run_name,
|
| 629 |
+
project = project,
|
| 630 |
+
trackio_space_id = trackio_space_id,
|
| 631 |
+
eval_strategy = eval_strategy,
|
| 632 |
+
eval_steps = eval_steps,
|
| 633 |
+
eval_delay = eval_delay,
|
| 634 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 635 |
+
prediction_loss_only = prediction_loss_only,
|
| 636 |
+
eval_on_start = eval_on_start,
|
| 637 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 638 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 639 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 640 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 641 |
+
save_only_model = save_only_model,
|
| 642 |
+
save_strategy = save_strategy,
|
| 643 |
+
save_steps = save_steps,
|
| 644 |
+
save_on_each_node = save_on_each_node,
|
| 645 |
+
save_total_limit = save_total_limit,
|
| 646 |
+
enable_jit_checkpoint = enable_jit_checkpoint,
|
| 647 |
+
push_to_hub = push_to_hub,
|
| 648 |
+
hub_token = hub_token,
|
| 649 |
+
hub_private_repo = hub_private_repo,
|
| 650 |
+
hub_model_id = hub_model_id,
|
| 651 |
+
hub_strategy = hub_strategy,
|
| 652 |
+
hub_always_push = hub_always_push,
|
| 653 |
+
hub_revision = hub_revision,
|
| 654 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 655 |
+
metric_for_best_model = metric_for_best_model,
|
| 656 |
+
greater_is_better = greater_is_better,
|
| 657 |
+
ignore_data_skip = ignore_data_skip,
|
| 658 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 659 |
+
full_determinism = full_determinism,
|
| 660 |
+
seed = seed,
|
| 661 |
+
data_seed = data_seed,
|
| 662 |
+
use_cpu = use_cpu,
|
| 663 |
+
accelerator_config = accelerator_config,
|
| 664 |
+
parallelism_config = parallelism_config,
|
| 665 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 666 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 667 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 668 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 669 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 670 |
+
remove_unused_columns = remove_unused_columns,
|
| 671 |
+
label_names = label_names,
|
| 672 |
+
train_sampling_strategy = train_sampling_strategy,
|
| 673 |
+
length_column_name = length_column_name,
|
| 674 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 675 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 676 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 677 |
+
ddp_backend = ddp_backend,
|
| 678 |
+
ddp_timeout = ddp_timeout,
|
| 679 |
+
fsdp = fsdp,
|
| 680 |
+
fsdp_config = fsdp_config,
|
| 681 |
+
deepspeed = deepspeed,
|
| 682 |
+
debug = debug,
|
| 683 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 684 |
+
do_train = do_train,
|
| 685 |
+
do_eval = do_eval,
|
| 686 |
+
do_predict = do_predict,
|
| 687 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 688 |
+
warmup_ratio = warmup_ratio,
|
| 689 |
+
logging_dir = logging_dir,
|
| 690 |
+
local_rank = local_rank,
|
| 691 |
+
model_init_kwargs = model_init_kwargs,
|
| 692 |
+
chat_template_path = chat_template_path,
|
| 693 |
+
dataset_text_field = dataset_text_field,
|
| 694 |
+
dataset_kwargs = dataset_kwargs,
|
| 695 |
+
dataset_num_proc = dataset_num_proc,
|
| 696 |
+
eos_token = eos_token,
|
| 697 |
+
pad_token = pad_token,
|
| 698 |
+
max_length = max_length,
|
| 699 |
+
packing = packing,
|
| 700 |
+
packing_strategy = packing_strategy,
|
| 701 |
+
padding_free = padding_free,
|
| 702 |
+
pad_to_multiple_of = pad_to_multiple_of,
|
| 703 |
+
eval_packing = eval_packing,
|
| 704 |
+
completion_only_loss = completion_only_loss,
|
| 705 |
+
assistant_only_loss = assistant_only_loss,
|
| 706 |
+
loss_type = loss_type,
|
| 707 |
+
activation_offloading = activation_offloading,**kwargs)
|
| 708 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 709 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 710 |
+
if unsloth_grpo_mini_batch is not None:
|
| 711 |
+
if self.generation_batch_size >= unsloth_grpo_mini_batch:
|
| 712 |
+
self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
|
| 713 |
+
else:
|
| 714 |
+
raise ValueError(
|
| 715 |
+
f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
|
| 716 |
+
f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
|
| 717 |
+
)
|
| 718 |
+
self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
|
| 719 |
+
self.max_seq_length = max_seq_length
|
| 720 |
+
|
| 721 |
+
pass
|
| 722 |
+
|
| 723 |
+
class _UnslothSFTTrainer(BaseTrainer):
|
| 724 |
+
""""""
|
| 725 |
+
|
| 726 |
+
_tag_names = ["trl", "sft"]
|
| 727 |
+
_name = "SFT"
|
| 728 |
+
|
| 729 |
+
def __init__(
|
| 730 |
+
self,
|
| 731 |
+
model: Union[str, PreTrainedModel],
|
| 732 |
+
args: Optional[Union[SFTConfig, TrainingArguments]] = None,
|
| 733 |
+
data_collator: Optional[DataCollator] = None,
|
| 734 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 735 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 736 |
+
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
|
| 737 |
+
compute_loss_func: Optional[Callable] = None,
|
| 738 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 739 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 740 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 741 |
+
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
|
| 742 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 743 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 744 |
+
formatting_func: Optional[Callable[[dict], str]] = None,
|
| 745 |
+
):
|
| 746 |
+
# Args
|
| 747 |
+
if args is None:
|
| 748 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 749 |
+
model_name = model_name.split("/")[-1]
|
| 750 |
+
args = SFTConfig(f"{model_name}-SFT")
|
| 751 |
+
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
|
| 752 |
+
dict_args = args.to_dict()
|
| 753 |
+
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
| 754 |
+
dict_args.pop("push_to_hub_token", None)
|
| 755 |
+
args = SFTConfig(**dict_args)
|
| 756 |
+
|
| 757 |
+
# Model
|
| 758 |
+
if isinstance(model, str):
|
| 759 |
+
model = create_model_from_path(model, **args.model_init_kwargs or {})
|
| 760 |
+
else:
|
| 761 |
+
if args.model_init_kwargs is not None:
|
| 762 |
+
logger.warning(
|
| 763 |
+
"You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
|
| 764 |
+
"The `model_init_kwargs` will be ignored."
|
| 765 |
+
)
|
| 766 |
+
model_id = model.config._name_or_path
|
| 767 |
+
|
| 768 |
+
# Processing class
|
| 769 |
+
if processing_class is None:
|
| 770 |
+
processing_class = AutoProcessor.from_pretrained(model_id)
|
| 771 |
+
|
| 772 |
+
# Handle pad token for processors or tokenizers
|
| 773 |
+
if isinstance(processing_class, ProcessorMixin):
|
| 774 |
+
tokenizer = processing_class.tokenizer
|
| 775 |
+
self._is_vlm = True
|
| 776 |
+
elif isinstance(processing_class, PreTrainedTokenizerBase):
|
| 777 |
+
tokenizer = processing_class
|
| 778 |
+
self._is_vlm = False
|
| 779 |
+
else:
|
| 780 |
+
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
|
| 781 |
+
|
| 782 |
+
if args.eos_token is not None:
|
| 783 |
+
eos_token = args.eos_token
|
| 784 |
+
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
| 785 |
+
if eos_token_id is None:
|
| 786 |
+
raise ValueError(
|
| 787 |
+
f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
|
| 788 |
+
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
|
| 789 |
+
"in the vocabulary before using it as an EOS token."
|
| 790 |
+
)
|
| 791 |
+
tokenizer.eos_token_id = eos_token_id
|
| 792 |
+
|
| 793 |
+
if args.chat_template_path is not None:
|
| 794 |
+
if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
|
| 795 |
+
with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
|
| 796 |
+
processing_class.chat_template = chat_template_file.read()
|
| 797 |
+
added_tokens = []
|
| 798 |
+
else:
|
| 799 |
+
model, processing_class, added_tokens = clone_chat_template(
|
| 800 |
+
model, processing_class, args.chat_template_path
|
| 801 |
+
)
|
| 802 |
+
else:
|
| 803 |
+
added_tokens = []
|
| 804 |
+
|
| 805 |
+
# Catch some wrong configurations related to VLMs
|
| 806 |
+
if self._is_vlm and args.packing:
|
| 807 |
+
raise ValueError(
|
| 808 |
+
"Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig."
|
| 809 |
+
)
|
| 810 |
+
if self._is_vlm and args.padding_free:
|
| 811 |
+
raise ValueError(
|
| 812 |
+
"Padding-free training is yet not supported for vision-language models. Please set "
|
| 813 |
+
"`padding_free=False` in the `SFTConfig`."
|
| 814 |
+
)
|
| 815 |
+
if self._is_vlm and args.assistant_only_loss:
|
| 816 |
+
raise ValueError(
|
| 817 |
+
"Assistant-only loss is not yet supported for vision-language models. Please set "
|
| 818 |
+
"`assistant_only_loss=False` in the `SFTConfig`."
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
# PEFT configuration and model wrapping
|
| 822 |
+
if False:
|
| 823 |
+
if added_tokens:
|
| 824 |
+
# Ensure that the added tokens are trainable
|
| 825 |
+
if peft_config.trainable_token_indices is None:
|
| 826 |
+
peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
|
| 827 |
+
elif "embed_tokens" not in peft_config.trainable_token_indices:
|
| 828 |
+
peft_config.trainable_token_indices["embed_tokens"] = added_tokens
|
| 829 |
+
else:
|
| 830 |
+
peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
|
| 831 |
+
|
| 832 |
+
# Ensure that the lm_head is trainable
|
| 833 |
+
if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
|
| 834 |
+
logger.warning(
|
| 835 |
+
"Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
|
| 836 |
+
"`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
|
| 837 |
+
"tokens, leading to degraded generation quality. To fix this, add "
|
| 838 |
+
"`modules_to_save=['lm_head']` to your PEFT configuration."
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
if peft_config.modules_to_save is None:
|
| 842 |
+
peft_config.modules_to_save = ["lm_head"]
|
| 843 |
+
else:
|
| 844 |
+
peft_config.modules_to_save.append("lm_head")
|
| 845 |
+
|
| 846 |
+
# In Prompt Tuning a small set of trainable virtual tokens [continuous prompt embeddings] is prepended to the
|
| 847 |
+
# input. We store the number of these tokens so we can account for them correctly when calculating accuracy.
|
| 848 |
+
self.num_virtual_tokens = 0
|
| 849 |
+
|
| 850 |
+
if False:
|
| 851 |
+
pass
|
| 852 |
+
if model.active_adapter in model.peft_config:
|
| 853 |
+
peft_model_config = model.peft_config[model.active_adapter]
|
| 854 |
+
self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)
|
| 855 |
+
|
| 856 |
+
# Data collator
|
| 857 |
+
# BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
|
| 858 |
+
# FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
|
| 859 |
+
self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
|
| 860 |
+
use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
|
| 861 |
+
if self.padding_free:
|
| 862 |
+
if data_collator is not None:
|
| 863 |
+
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
|
| 864 |
+
if args.packing and args.packing_strategy == "wrapped":
|
| 865 |
+
logger.warning(
|
| 866 |
+
"You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not "
|
| 867 |
+
"recommended. Please refer to the documentation to understand why this is not recommended."
|
| 868 |
+
)
|
| 869 |
+
if not use_flash_attention:
|
| 870 |
+
logger.warning(
|
| 871 |
+
"Padding-free training is enabled, but the attention implementation is not set to a supported "
|
| 872 |
+
"flash attention variant. Padding-free training flattens batches into a single sequence, and only "
|
| 873 |
+
"the following implementations are known to reliably support this: "
|
| 874 |
+
f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
|
| 875 |
+
"unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
|
| 876 |
+
"configuration to one of these supported options or verify that your attention mechanism can "
|
| 877 |
+
"handle flattened sequences."
|
| 878 |
+
)
|
| 879 |
+
# Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format
|
| 880 |
+
# is prompt-completion, and False if the dataset format is language modeling.
|
| 881 |
+
dataset_sample = next(iter(train_dataset))
|
| 882 |
+
if args.completion_only_loss is None:
|
| 883 |
+
self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
|
| 884 |
+
else:
|
| 885 |
+
self.completion_only_loss = args.completion_only_loss
|
| 886 |
+
|
| 887 |
+
self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
|
| 888 |
+
# Unsloth: override _is_vlm for VLM models that pass a bare tokenizer
|
| 889 |
+
if not self._is_vlm and self._is_vision_dataset:
|
| 890 |
+
_m = model
|
| 891 |
+
if hasattr(_m, "model"): _m = _m.model
|
| 892 |
+
if hasattr(getattr(_m, "config", None), "vision_config") or\
|
| 893 |
+
_m.__class__.__name__.endswith("ForConditionalGeneration"):
|
| 894 |
+
self._is_vlm = True
|
| 895 |
+
if self._is_vision_dataset and not self._is_vlm:
|
| 896 |
+
raise ValueError(
|
| 897 |
+
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
|
| 898 |
+
"model does not seem to be a vision-language model. Please check your model and dataset."
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
if data_collator is None and not self._is_vision_dataset:
|
| 902 |
+
# Get the pad token: if not provided, use the one from the processing class or the eos token
|
| 903 |
+
# if the processing class does not have a pad token.
|
| 904 |
+
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
|
| 905 |
+
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
|
| 906 |
+
if pad_token_id is None:
|
| 907 |
+
raise ValueError(
|
| 908 |
+
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
|
| 909 |
+
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
|
| 910 |
+
"in the vocabulary before using it as a padding token."
|
| 911 |
+
)
|
| 912 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 913 |
+
pad_token_id=pad_token_id,
|
| 914 |
+
completion_only_loss=self.completion_only_loss,
|
| 915 |
+
padding_free=self.padding_free,
|
| 916 |
+
pad_to_multiple_of=args.pad_to_multiple_of,
|
| 917 |
+
)
|
| 918 |
+
elif data_collator is None and self._is_vision_dataset:
|
| 919 |
+
data_collator = DataCollatorForVisionLanguageModeling(
|
| 920 |
+
processor=processing_class,
|
| 921 |
+
max_length=args.max_length,
|
| 922 |
+
completion_only_loss=self.completion_only_loss,
|
| 923 |
+
pad_to_multiple_of=args.pad_to_multiple_of,
|
| 924 |
+
dataset_text_field=args.dataset_text_field,
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
|
| 928 |
+
logger.warning(
|
| 929 |
+
"You are using packing, but the attention implementation is not set to a supported flash attention "
|
| 930 |
+
"variant. Packing gathers multiple samples into a single sequence, and only the following "
|
| 931 |
+
f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. "
|
| 932 |
+
"Using other implementations may lead to cross-contamination between samples. To avoid this, either "
|
| 933 |
+
"disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration "
|
| 934 |
+
"to one of these supported options."
|
| 935 |
+
)
|
| 936 |
+
if args.assistant_only_loss and not is_conversational(dataset_sample):
|
| 937 |
+
raise ValueError(
|
| 938 |
+
"You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only "
|
| 939 |
+
"supported for conversational datasets."
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
# Dataset
|
| 943 |
+
# Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
|
| 944 |
+
# preprocessing [e.g., image-to-pixel conversion] is too costly and done on the fly instead.
|
| 945 |
+
skip_prepare_dataset = (
|
| 946 |
+
args.dataset_kwargs is not None
|
| 947 |
+
and args.dataset_kwargs.get("skip_prepare_dataset", False)
|
| 948 |
+
or self._is_vision_dataset
|
| 949 |
+
)
|
| 950 |
+
if not skip_prepare_dataset:
|
| 951 |
+
if self.completion_only_loss and formatting_func:
|
| 952 |
+
raise ValueError(
|
| 953 |
+
"A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
|
| 954 |
+
"Using a formatter converts the dataset to a language modeling type, conflicting with "
|
| 955 |
+
"completion-only loss. To resolve this, apply your formatting function before passing the "
|
| 956 |
+
"dataset, or disable `completion_only_loss` in `SFTConfig`."
|
| 957 |
+
)
|
| 958 |
+
self._unsloth_model_ref = model
|
| 959 |
+
train_dataset = self._prepare_dataset(
|
| 960 |
+
train_dataset, processing_class, args, args.packing, formatting_func, "train"
|
| 961 |
+
)
|
| 962 |
+
if eval_dataset is not None:
|
| 963 |
+
packing = args.packing if args.eval_packing is None else args.eval_packing
|
| 964 |
+
if isinstance(eval_dataset, dict):
|
| 965 |
+
eval_dataset = {
|
| 966 |
+
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
|
| 967 |
+
for key, dataset in eval_dataset.items()
|
| 968 |
+
}
|
| 969 |
+
else:
|
| 970 |
+
eval_dataset = self._prepare_dataset(
|
| 971 |
+
eval_dataset, processing_class, args, packing, formatting_func, "eval"
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
# Loss function
|
| 975 |
+
if args.loss_type == "nll":
|
| 976 |
+
pass # use the default loss
|
| 977 |
+
elif args.loss_type == "dft":
|
| 978 |
+
if compute_loss_func is not None:
|
| 979 |
+
raise ValueError(
|
| 980 |
+
"You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
|
| 981 |
+
"When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
|
| 982 |
+
"`compute_loss_func` is not allowed."
|
| 983 |
+
)
|
| 984 |
+
compute_loss_func = dft_loss
|
| 985 |
+
else:
|
| 986 |
+
raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
|
| 987 |
+
|
| 988 |
+
# Initialize the metrics
|
| 989 |
+
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
|
| 990 |
+
self._total_train_tokens = 0
|
| 991 |
+
|
| 992 |
+
# Initialize the Trainer. Parent class will handle:
|
| 993 |
+
# - DeepSpeed configuration [through create_accelerator_and_postprocess]
|
| 994 |
+
# - FSDP setup
|
| 995 |
+
# - Distributed training setup
|
| 996 |
+
# - Optimizer and scheduler creation
|
| 997 |
+
|
| 998 |
+
super().__init__(
|
| 999 |
+
model=model,
|
| 1000 |
+
args=args,
|
| 1001 |
+
data_collator=data_collator,
|
| 1002 |
+
train_dataset=train_dataset,
|
| 1003 |
+
eval_dataset=eval_dataset,
|
| 1004 |
+
processing_class=processing_class,
|
| 1005 |
+
compute_loss_func=compute_loss_func,
|
| 1006 |
+
compute_metrics=compute_metrics,
|
| 1007 |
+
callbacks=callbacks,
|
| 1008 |
+
optimizers=optimizers,
|
| 1009 |
+
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
|
| 1010 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
# Initialize activation offloading context
|
| 1014 |
+
if self.args.activation_offloading:
|
| 1015 |
+
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
|
| 1016 |
+
else:
|
| 1017 |
+
self.maybe_activation_offload_context = contextlib.nullcontext()
|
| 1018 |
+
|
| 1019 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 1020 |
+
if hasattr(self.model, "add_model_tags"):
|
| 1021 |
+
self.model.add_model_tags(self._tag_names)
|
| 1022 |
+
|
| 1023 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 1024 |
+
|
| 1025 |
+
def _prepare_dataset(
|
| 1026 |
+
self,
|
| 1027 |
+
dataset: Union[Dataset, IterableDataset],
|
| 1028 |
+
processing_class,
|
| 1029 |
+
args,
|
| 1030 |
+
packing: bool,
|
| 1031 |
+
formatting_func: Optional[Callable[[dict], str]],
|
| 1032 |
+
dataset_name: str,
|
| 1033 |
+
) -> Union[Dataset, IterableDataset]:
|
| 1034 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
| 1035 |
+
try:
|
| 1036 |
+
if isinstance(dataset, ConstantLengthDataset): return dataset
|
| 1037 |
+
except:
|
| 1038 |
+
pass
|
| 1039 |
+
|
| 1040 |
+
map_kwargs = {}
|
| 1041 |
+
use_desc = isinstance(dataset, Dataset)
|
| 1042 |
+
is_vlm = hasattr(processing_class, "tokenizer")
|
| 1043 |
+
tokenizer = processing_class
|
| 1044 |
+
if is_vlm: tokenizer = processing_class.tokenizer
|
| 1045 |
+
|
| 1046 |
+
# Dynamic detection: check if model's module defines a function
|
| 1047 |
+
# that requires token_type_ids when is_training=True
|
| 1048 |
+
import sys as _sys
|
| 1049 |
+
_needs_token_type_ids = False
|
| 1050 |
+
# Split to avoid compiler substring match on masking_utils names
|
| 1051 |
+
_ccm = 'create_' + 'causal_mask_mapping'
|
| 1052 |
+
_model = getattr(self, '_unsloth_model_ref', None) or getattr(self, 'model', None)
|
| 1053 |
+
if _model is not None:
|
| 1054 |
+
for _m in (_model, getattr(_model, 'model', None)):
|
| 1055 |
+
if _m is None: continue
|
| 1056 |
+
_mod = _sys.modules.get(type(_m).__module__)
|
| 1057 |
+
if _mod is not None and hasattr(_mod, _ccm):
|
| 1058 |
+
_needs_token_type_ids = True
|
| 1059 |
+
break
|
| 1060 |
+
|
| 1061 |
+
if not _needs_token_type_ids:
|
| 1062 |
+
# Fallback: model not yet available, check processor class MRO
|
| 1063 |
+
for _base in type(processing_class).__mro__:
|
| 1064 |
+
_base_mod = getattr(_base, '__module__', '')
|
| 1065 |
+
if 'transformers.models.' in _base_mod:
|
| 1066 |
+
_modeling_mod = _base_mod.replace('.processing_', '.modeling_')
|
| 1067 |
+
_mod = _sys.modules.get(_modeling_mod)
|
| 1068 |
+
if _mod is not None and hasattr(_mod, _ccm):
|
| 1069 |
+
_needs_token_type_ids = True
|
| 1070 |
+
break
|
| 1071 |
+
if _needs_token_type_ids and hasattr(args, 'remove_unused_columns'):
|
| 1072 |
+
args.remove_unused_columns = False
|
| 1073 |
+
|
| 1074 |
+
# Get max length
|
| 1075 |
+
max_seq_length = getattr(args, "max_length", 0)
|
| 1076 |
+
if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
|
| 1077 |
+
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
|
| 1078 |
+
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
|
| 1079 |
+
if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
|
| 1080 |
+
dataset_text_field = getattr(args, "dataset_text_field", "text")
|
| 1081 |
+
do_truncation = max_seq_length != 0
|
| 1082 |
+
do_formatting_func = False
|
| 1083 |
+
do_tokenize = True
|
| 1084 |
+
|
| 1085 |
+
# Get correct column names
|
| 1086 |
+
column_names = set(next(iter(dataset)).keys())
|
| 1087 |
+
used_column_names = ["input_ids"]
|
| 1088 |
+
if "attention_mask" in column_names:
|
| 1089 |
+
used_column_names.append("attention_mask")
|
| 1090 |
+
if _needs_token_type_ids:
|
| 1091 |
+
used_column_names.append("token_type_ids")
|
| 1092 |
+
|
| 1093 |
+
# Check if already tokenized so skip
|
| 1094 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 1095 |
+
if "labels" in column_names:
|
| 1096 |
+
# Most likely forgot data collator!
|
| 1097 |
+
            if is_vlm and not hasattr(tokenizer, "pad"):
                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
            self.data_collator = DataCollatorForSeq2Seq(tokenizer)
            used_column_names.append("labels")
            do_tokenize = False
        elif "input_ids" in column_names:
            # Skip dataset prep, and set data collator
            if is_vlm and not hasattr(tokenizer, "pad"):
                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
            self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
            do_tokenize = False
        elif dataset_text_field not in column_names:
            do_formatting_func = True
            if formatting_func is None:
                raise RuntimeError("Unsloth: You must specify a `formatting_func`")
        pass

        if do_tokenize:
            # Check double BOS tokens
            if do_formatting_func:
                test_text = formatting_func(next(iter(dataset)))
                if not isinstance(test_text, list):
                    raise ValueError(
                        "Unsloth: The `formatting_func` should return a list of processed strings."
                    )
                test_text = test_text[0]
            else:
                test_text = next(iter(dataset))[dataset_text_field][0]

            # Get chat template
            chat_template = getattr(processing_class, 'chat_template', '')
            if chat_template == '' and is_vlm:
                chat_template = getattr(tokenizer, 'chat_template', '')
            if chat_template is None:
                chat_template = ''

            # Get bos_token
            add_special_tokens = True
            bos_token_1 = getattr(processing_class, 'bos_token', None)
            bos_token_2 = getattr(tokenizer, 'bos_token', None)
            bos_token = bos_token_1 or bos_token_2

            if bos_token is not None:
                if test_text.startswith(bos_token) or bos_token in chat_template:
                    add_special_tokens = False
                    print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
            pass

            # Create tokenize function
            def _tokenize(example):
                return tokenizer(
                    example[dataset_text_field] if not do_formatting_func else formatting_func(example),
                    truncation = do_truncation,
                    max_length = max_seq_length,
                    return_token_type_ids = _needs_token_type_ids,
                    add_special_tokens = add_special_tokens,
                )
            pass

            if not isinstance(dataset, IterableDataset):
                import multiprocessing as _mp
                if _mp.get_start_method() != 'fork':
                    dataset_num_proc = None
                else:
                    dataset_num_proc = getattr(args, "dataset_num_proc", None)
                    if dataset_num_proc is None:
                        import psutil
                        dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
                        memory_gb_left = psutil.virtual_memory().available / (1024**3)
                        if memory_gb_left <= 2:
                            dataset_num_proc = 1
                        else:
                            dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
                map_kwargs["num_proc"] = dataset_num_proc
            else:
                map_kwargs["batch_size"] = dataset._ex_iterable.batch_size

            if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
            import warnings as _w
            with _w.catch_warnings():
                _w.filterwarnings("ignore", message=".*couldn't be hashed properly.*")
                dataset = dataset.map(_tokenize, batched = True, remove_columns = list(column_names), **map_kwargs)

            # If VLM, switch data collator since .pad is needed!
            if is_vlm and not hasattr(processing_class, "pad"):
                data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
                self.data_collator = data_collator
            pass
        pass
        if packing:
            # Try using new packing which works in TRL
            try:
                pack_dataset
            except:
                print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
                return dataset

            if max_seq_length == 0:
                raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")

            if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
            dataset = pack_dataset(
                dataset.select_columns(used_column_names),
                max_seq_length,
                getattr(args, "packing_strategy", "bfd"),
                map_kwargs,
            )
        pass
        return dataset

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
        # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
        # dataset. So we need to override the default signature columns to include "completion_mask" as well.
        if self._signature_columns is None:
            if self._is_vision_dataset:
                self._signature_columns = ["messages", "prompt", "completion", "images", "input_ids", "labels", "attention_mask", "seq_lengths", "completion_mask", "assistant_masks"]
            else:
                self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]

    def compute_loss(
        self, model, inputs, return_outputs = False, num_items_in_batch = None
    ):
        outputs = super().compute_loss(
            model,
            inputs,
            return_outputs = return_outputs,
            num_items_in_batch = num_items_in_batch,
        )
        return outputs

    # Override training step to add activation offloading context.
    def training_step(self, *args, **kwargs):
        with self.maybe_activation_offload_context:
            return super().training_step(*args, **kwargs)

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs.update(metrics)
        super().log(logs, start_time)
        self._metrics[mode].clear()

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
class UnslothSFTTrainer(_UnslothSFTTrainer):
    """

    Trainer for Supervised Fine-Tuning (SFT) method.

    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.

    Example:

    ```python
    from datasets import load_dataset
    from trl import SFTTrainer

    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

    trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
              config) with the keyword arguments in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object.

            If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss
            as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
        args ([`SFTConfig`], *optional*):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
            and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
            [prompt-completion](#prompt-completion) type. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
            If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
        compute_loss_func (`Callable`, *optional*):
            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss
            function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618)
            used by [`Trainer`].
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a
            [`~transformers.EvalPrediction`] and return a dictionary mapping metric names to metric values. When
            passing [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
            function needs to calculate and return the global summary statistics rather than accumulating the
            batch-level statistics.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
            `args`. Incompatible with the `optimizers` argument.

            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
            initializing the Trainer.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        formatting_func (`Callable`, *optional*):
            Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
            converts the dataset into a [language modeling](#language-modeling) type.

    """
    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_loss_func = None,
        compute_metrics = None,
        callbacks = None,
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        if args is None: args = UnslothSFTConfig()
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
        if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().weight.dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            # Forced float32 training
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Mixed precision training
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
            # args.mixed_precision is a new argument which needs to be set now
        elif mixed_precision_dtype == 'bfloat16':
            # Both False since bfloat16 full finetuning doesn't do any autocasting.
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
            if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
            # args.mixed_precision is a new argument which needs to be set now

        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if model is not None:
            _warnings_issued = getattr(model, 'warnings_issued', None)
            if _warnings_issued is None:
                model.warnings_issued = {}
            elif not isinstance(_warnings_issued, dict):
                try:
                    model.warnings_issued = dict(_warnings_issued)
                except Exception:
                    model.warnings_issued = {}
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
            elif args_max_seq_length is not None and model_max_seq_length is not None:
                if args_max_seq_length > model_max_seq_length:
                    print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
                          'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
                    args.max_seq_length = model_max_seq_length
        if 'max_length' not in locals() and not hasattr(args, 'max_length'):
            pass
        else:
            if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
                if hasattr(args, 'max_length'):
                    args.max_length = args.max_seq_length
                    max_length = args.max_length
            else:
                model_max_length = getattr(model, 'max_seq_length', None)
                if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
                if model_max_length is not None:
                    args.max_length = model_max_length
                    max_length = args.max_length
                elif hasattr(args, 'max_length') and args.max_length is not None:
                    max_length = args.max_length
                    # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
                    setattr(model, 'max_seq_length', max_length)
                else:
                    print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
                    args.max_length = 1024
        if model is not None and hasattr(model, 'for_training'):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('sft_trainer', other_metrics)
        IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
        from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
        from unsloth_zoo.training_utils import fix_zero_training_loss
        if 'tokenizer' not in locals(): tokenizer = processing_class
        fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
        fix_zero_training_loss(model, tokenizer, train_dataset)

        # [TODO] Fix up DataParallel multiplying batch sizes
        # [TODO] DDP works, but DP seems to not work? [TODO]
        if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
            if getattr(args, "_n_gpu", 1) != 1:
                args._n_gpu = 1
        if "model" in locals() and hasattr(model, "for_training"):
            model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_loss_func = compute_loss_func,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,
            **kwargs,
        )
        if "model" in locals() and hasattr(model, "for_inference"):
            model.for_inference()
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
        pass
        if hasattr(self, 'train'):
            self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
        pass
        if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
            _vllm_tok = self.llm.get_tokenizer()
            _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
            if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
                _vllm_tok.chat_template = _pc.chat_template
        pass

    pass


if hasattr(logger, "addFilter"):
    import logging
    class HideLoggingMessage(logging.Filter):
        def __init__(self, text): self.text = text
        def filter(self, x): return not (self.text in x.getMessage())
    pass
    logger.addFilter(HideLoggingMessage("`use_cache=True`"))

unsloth_compiled_cache/moe_utils.py
ADDED
|
@@ -0,0 +1,1320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import torch
import torch.nn.functional as F
import os
import shutil
import sys
import importlib.util
from typing import Optional, Tuple
from torch.autograd import Function

# Get compile location
UNSLOTH_COMPILE_LOCATION = os.environ.get(
    "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache"
)


def _get_compile_location() -> str:
    return os.path.abspath(
        os.environ.get("UNSLOTH_COMPILE_LOCATION", UNSLOTH_COMPILE_LOCATION)
    )


def _log_info(message: str):
    if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1":
        print(message)


def install_to_cache(source_path, destination_filename=None):
    """
    Copies a file to the unsloth_compiled_cache directory
    to ensure it is available for compiled modules.
    """
    compile_location = _get_compile_location()
    if not os.path.exists(compile_location):
        try:
            os.makedirs(compile_location)
        except:
            pass

    current_file = os.path.abspath(source_path)
    if destination_filename is None:
        destination_filename = os.path.basename(current_file)

    destination = os.path.abspath(os.path.join(compile_location, destination_filename))

    # If source and dest are different, copy.
    if current_file != destination:
        try:
            shutil.copy(current_file, destination)
        except Exception:
            pass


install_to_cache(__file__, "moe_utils.py")

_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None


def _load_cached_moe_utils_module():
    global _CACHED_MOE_UTILS_MODULE

    cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils.py"))
    current_file = os.path.abspath(__file__)
    if not os.path.isfile(cache_file) or cache_file == current_file:
        return None

    try:
        module_name = "unsloth_cached_moe_utils"
        module = sys.modules.get(module_name, None)
        if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
            _CACHED_MOE_UTILS_MODULE = module
            return module

        spec = importlib.util.spec_from_file_location(module_name, cache_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        _CACHED_MOE_UTILS_MODULE = module
        return module
    except Exception:
        return None


def get_forward_moe_backend():
    """
    Resolve forward_moe_backend from the compiled cache copy when available.
    Falls back to the local module definition.
    """
    global _CACHED_FORWARD_MOE_BACKEND
    module = _load_cached_moe_utils_module()
    if module is not None and hasattr(module, "forward_moe_backend"):
        _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend
        return _CACHED_FORWARD_MOE_BACKEND

    _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend
    return _CACHED_FORWARD_MOE_BACKEND

# ============================================================================
# Grouped MM wrapper
# ============================================================================
# Simple wrapper around torch._grouped_mm that ensures contiguous inputs.
# Native backward works correctly - no custom autograd needed.
# ============================================================================


def _grouped_mm_with_backward_fix(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Grouped matmul with working backward pass.

    Uses native torch._grouped_mm with contiguous inputs for correct gradients.
    """
    return torch._grouped_mm(inputs, weight, offs=offsets)


# Global flag to check if grouped GEMM is available
_GROUPED_GEMM_AVAILABLE = None
_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")

# Check if GPU supports torch._grouped_mm (verified via runtime check)
_TORCH_GROUPED_MM_SUPPORTED = None


def _check_torch_grouped_mm_supported():
    """
    Check if torch._grouped_mm is actually supported on the current GPU.
    We check for existence and verify with a dummy call.
    A runtime probe is the only reliable check.
    """
    global _TORCH_GROUPED_MM_SUPPORTED
    if _TORCH_GROUPED_MM_SUPPORTED is not None: return _TORCH_GROUPED_MM_SUPPORTED

    if not _TORCH_GROUPED_MM_AVAILABLE:
        _TORCH_GROUPED_MM_SUPPORTED = False
        return False

    if not torch.cuda.is_available():
        _TORCH_GROUPED_MM_SUPPORTED = False
        return False

    try:
        # Attempt a dummy grouped_mm call to verify support.
        # This handles cases where the symbol exists but hardware is unsupported (e.g. < H100).
        # It also allows support on newer hardware or backports without code changes.
        device = torch.cuda.current_device()
        dtype = torch.float16

        # Minimal dummy data: 1 expert, 1 token, dim 8 (safe alignment)
        x = torch.ones((1, 8), device=device, dtype=dtype)
        w = torch.ones((1, 8, 8), device=device, dtype=dtype)
        offs = torch.tensor([1], device=device, dtype=torch.int32)

        torch._grouped_mm(x, w, offs=offs)
        del x, w, offs
        _TORCH_GROUPED_MM_SUPPORTED = True
    except Exception:
        _TORCH_GROUPED_MM_SUPPORTED = False

    return _TORCH_GROUPED_MM_SUPPORTED


_TRITON_ALLOCATOR_INITIALIZED = False
_PERSISTENT_BUFFER = None


def _init_triton_allocator():
    """
    Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
    This significantly reduces GPU utilization fluctuation.
    """
    global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
    if _TRITON_ALLOCATOR_INITIALIZED: return

    try:
        import triton

        # Create a persistent buffer that grows as needed
        # This avoids allocating new memory on every kernel call

        def persistent_alloc_fn(size: int, alignment: int, stream):
            global _PERSISTENT_BUFFER
            # Round up size to avoid frequent reallocations
            # Round to nearest 128 bytes for alignment
            rounded_size = ((size + 128 - 1) // 128) * 128

            if (
                _PERSISTENT_BUFFER is None
                or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size()
                < rounded_size
            ):
                # Allocate with small headroom (10%) to reduce reallocations
                # Use ByteTensor (uint8) for raw byte storage
                _PERSISTENT_BUFFER = torch.empty(
                    int(rounded_size * 1.1), device="cuda", dtype=torch.uint8
                )
                _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
            return _PERSISTENT_BUFFER

        triton.set_allocator(persistent_alloc_fn)
        triton._unsloth_allocator_set = True
        _TRITON_ALLOCATOR_INITIALIZED = True
    except Exception:
        pass

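The growth policy of `persistent_alloc_fn` above can be traced without a GPU. The sketch below uses `next_buffer_size`, a hypothetical helper (not part of the module) that mirrors only the round-up-to-128-bytes and 10%-headroom arithmetic:

```python
def next_buffer_size(requested: int, current: int) -> int:
    """Return the byte size the persistent buffer would grow to, or keep."""
    # Round the request up to the next multiple of 128 bytes for alignment
    rounded = ((requested + 128 - 1) // 128) * 128
    if current >= rounded:
        return current          # existing buffer is large enough; reuse it
    return int(rounded * 1.1)   # grow with 10% headroom to cut future reallocations

# A 1000-byte request rounds to 1024, allocated as 1126; a later 500-byte
# request rounds to 512 and reuses the existing 1126-byte buffer.
```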
def _check_grouped_gemm_available():
    """Check if Unsloth grouped GEMM kernels are available."""
    if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False

    global _GROUPED_GEMM_AVAILABLE
    if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE

    try:
        from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma
        _GROUPED_GEMM_AVAILABLE = True
        _init_triton_allocator()
    except (ImportError, ModuleNotFoundError):
        _GROUPED_GEMM_AVAILABLE = False
    return _GROUPED_GEMM_AVAILABLE


from functools import lru_cache


@lru_cache(maxsize=1)
def select_moe_backend():
    """
    Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability.
    Choices: "grouped_mm", "unsloth_triton", "native_torch".
    Default if unspecified: "grouped_mm".
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    requested = os.environ.get("UNSLOTH_MOE_BACKEND")
    if requested:
        if requested == "grouped_mm" and _check_torch_grouped_mm_supported():
            return "grouped_mm"
        if requested == "unsloth_triton" and _check_grouped_gemm_available():
            return "unsloth_triton"
        if requested == "native_torch":
            return "native_torch"
        _log_info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")

    if _check_torch_grouped_mm_supported():
        _log_info("Unsloth: Using MoE backend 'grouped_mm'")
        return "grouped_mm"
    if _check_grouped_gemm_available():
        _log_info("Unsloth: Using MoE backend 'unsloth_triton'")
        return "unsloth_triton"
    return "native_torch"

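The precedence `select_moe_backend` implements can be sketched in plain Python. `pick_backend` is an illustrative stand-in (not part of the module) with the hardware/import probes replaced by booleans:

```python
def pick_backend(requested, grouped_mm_ok, triton_ok):
    """Honor an explicit, usable request; otherwise fall through in priority order."""
    if requested == "grouped_mm" and grouped_mm_ok:
        return "grouped_mm"
    if requested == "unsloth_triton" and triton_ok:
        return "unsloth_triton"
    if requested == "native_torch":
        return "native_torch"
    # No (usable) request: prefer grouped_mm, then the Triton kernels, then the eager loop
    if grouped_mm_ok:
        return "grouped_mm"
    if triton_ok:
        return "unsloth_triton"
    return "native_torch"
```

Note that an unavailable request does not raise; it silently falls back to the best available backend, which is why the real function logs the fallback.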
def forward_moe_backend(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Dispatch MoE forward to the selected backend.
    Centralizes backend selection to keep model-specific patches minimal.
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    backend = select_moe_backend()
    if backend == "grouped_mm":
        return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights)
    if backend == "unsloth_triton":
        return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights)
    return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights)


@torch.no_grad()
def _get_routing_indices(selected_experts, num_experts):
    """
    Compute token→expert mapping for grouped GEMM.
    Uses bincount instead of histc to avoid float conversion overhead.

    Returns:
        token_counts_by_expert: (num_experts,) token counts per expert
        gather_indices: (total_tokens,) indices for gathering tokens in expert order
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    flat_experts = selected_experts.view(-1)

    # bincount is faster than histc since it doesn't require float conversion
    token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)

    # argsort with stable=True preserves order within each expert
    gather_indices = flat_experts.argsort(stable=True)

    return token_counts_by_expert, gather_indices

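The bincount-plus-stable-argsort trick in `_get_routing_indices` has a direct list equivalent that is handy for checking expectations on tiny inputs. `routing_indices` below is a hypothetical pure-Python mirror, not a module function:

```python
def routing_indices(selected_experts, num_experts):
    """Counts per expert plus a stable ordering of token slots by expert id."""
    counts = [0] * num_experts
    for e in selected_experts:
        counts[e] += 1
    # A stable sort keeps the original order of tokens within each expert,
    # matching argsort(stable=True) above
    gather = sorted(range(len(selected_experts)), key=lambda i: selected_experts[i])
    return counts, gather

counts, gather = routing_indices([2, 0, 2, 1], num_experts=4)
# counts == [1, 1, 2, 0]; gather == [1, 3, 0, 2]
# The cumulative sum of counts ([1, 2, 4, 4]) is the kind of per-expert
# boundary the grouped GEMM `offs` argument encodes.
```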
def _silu_and_mul(x):
    """Fused SiLU activation and element-wise multiply for gate/up projections."""
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

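For reference, the scalar math behind `_silu_and_mul` is silu(g) = g * sigmoid(g), multiplied by the up projection. `silu_and_mul` below is an illustrative scalar version, not part of the module:

```python
import math

def silu_and_mul(gate, up):
    """Scalar gate/up combine: silu(gate) * up, with silu(g) = g * sigmoid(g)."""
    return gate * (1.0 / (1.0 + math.exp(-gate))) * up

# silu(0) == 0, so a zero gate blocks the up projection entirely
```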
# ============================================================================
# Separated LoRA Helper Functions
# ============================================================================


def _has_lora_adapters(param) -> bool:
    """Check if parameter has active LoRA adapters (PEFT ParamWrapper)."""
    # Check if this is a PEFT LoRA wrapper
    if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"):
        return False
    if hasattr(param, "disable_adapters") and param.disable_adapters:
        return False
    if hasattr(param, "merged") and param.merged:
        return False
    return len(param.lora_A) > 0


def _extract_lora_from_wrapper(
    wrapper, adapter_name: str = "default", experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
    """
    Extract LoRA weights from PEFT ParamWrapper for MoE separated computation.

    PEFT ParamWrapper for 3D parameters creates:
    - lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim)
    - lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R)

    For grouped_mm: X @ first_weight @ second_weight

    STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear
        gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I)
        down_proj: (E, H, I) - input X is (N, I), output is (N, H)

        For gate_up with (E, 2*I, H):
            lora_A: (E*R, H), lora_B: (2*I, E*R)
            Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
            first_weight from lora_A: (E*R, H) -> (E, H, R) after view/permute
            second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) after view/permute

    TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm
        gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I)
        down_proj: (E, I, H) - input X is (N, I), output is (N, H)

        For gate_up with (E, H, 2*I):
            lora_A: (E*R, H), lora_B: (2*I, E*R)
            Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
            first_weight from lora_A: (E*R, H) -> (E, H, R)
            second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I)

    Returns:
        (first_weight, second_weight, scaling, num_experts) or None
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    try:
        if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"):
            return None

        if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters:
            return None
        if hasattr(wrapper, "merged") and wrapper.merged:
            return None

        if not wrapper.lora_A:
            return None

        if adapter_name not in wrapper.lora_A:
            adapter_name = list(wrapper.lora_A.keys())[0]

        lora_A_module = wrapper.lora_A[adapter_name]
        lora_B_module = wrapper.lora_B[adapter_name]

        weight_A = lora_A_module.weight  # (E*R, dim1)
        weight_B = lora_B_module.weight  # (dim2, E*R)
        scaling = wrapper.scaling[adapter_name]
        num_experts = getattr(wrapper, "num_experts", 1)

        # GET EXPERTS MODULE TO CHECK FOR REGISTERED EXTRACTOR
        if experts_module is None:
            experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None

        # Check for model-specific LoRA extractor attached to the experts module
        extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None)

        if extractor_fn is not None:
            return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)

        # DEFAULT BEHAVIOR (Standard Format / Non-MoE)
        if num_experts > 1:
            total_rank = weight_A.shape[0]
            rank_per_expert = total_rank // num_experts
            dim1 = weight_A.shape[1]
            dim2 = weight_B.shape[0]

            # STANDARD FORMAT (Qwen3-MoE / GLM4):
            # Base weights are (E, out_dim, in_dim) for F.linear.
            # LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R).
            # We need X @ (E, in_dim, R) @ (E, R, out_dim).

            # first_weight: (E, in_dim, R) - from lora_A
            # second_weight: (E, R, out_dim) - from lora_B
            first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
            first_weight = first_weight.permute(0, 2, 1).contiguous()  # (E, dim1, R)

            # second_weight (B): (E, R, out_dim)
            second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
            second_weight = second_weight.permute(1, 2, 0).contiguous()  # (E, R, dim2)
        else:
            # Non-MoE case: return weights for X @ A.T @ B.T
            first_weight = weight_A.T  # (dim1, R)
            second_weight = weight_B.T  # (R, dim2)

        return first_weight, second_weight, scaling, num_experts
    except Exception:
        return None

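The view/permute bookkeeping above is easy to misread, and tracing it on shapes alone helps. `lora_moe_shapes` is a hypothetical helper (not part of the module) that reproduces only the shape arithmetic of the standard-format branch:

```python
def lora_moe_shapes(E, R, in_dim, out_dim):
    """Trace the standard-format reshapes on shapes only (no tensors needed)."""
    a = (E * R, in_dim)     # PEFT lora_A.weight
    b = (out_dim, E * R)    # PEFT lora_B.weight
    first = (E, in_dim, R)  # view(E, R, in_dim) then permute(0, 2, 1)
    second = (E, R, out_dim)  # view(out_dim, E, R) then permute(1, 2, 0)
    return a, b, first, second

# e.g. 8 experts, rank 16, hidden 1024, fused gate/up output 5632
# lora_moe_shapes(8, 16, 1024, 5632)
# -> ((128, 1024), (5632, 128), (8, 1024, 16), (8, 16, 5632))
```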
def _extract_lora_weights(
    param, adapter_name: str = "default", num_experts: int = None, experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
    """
    Extract LoRA A and B weights from PEFT ParamWrapper.

    This is a compatibility wrapper around _extract_lora_from_wrapper.
    Use _extract_lora_from_wrapper directly for new code.

    Returns:
        (first_weight, second_weight, scaling) for (X @ first) @ second
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    # Set num_experts on param if provided, so _extract_lora_from_wrapper can use it
    if num_experts is not None and not hasattr(param, "num_experts"):
        param.num_experts = num_experts

    result = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module)
    if result is None:
        return None
    # Return first 3 elements (first_weight, second_weight, scaling) without num_experts
    return result[0], result[1], result[2]


def _get_base_weight(param):
    """Get base weight from potentially wrapped parameter or module."""
    # This Unsloth Zoo code section is licensed under AGPL3

    # Recursively unwrap PEFT layers
    while hasattr(param, "base_layer"):
        param = param.base_layer

    if hasattr(param, "get_param"):
        return param.get_param()

    # Handle Modules (Linear, etc.)
    if hasattr(param, "weight"):
        return param.weight

    return param


def _get_lora_wrapper_for_param(experts_module, param_name):
    """
    Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
    Uses the explicit key stored in __dict__ if available.
    Does NOT lazily setup wrappers as that requires traversing logic not present here.
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    if hasattr(experts_module, f"{param_name}_lora_wrapper"):
        return getattr(experts_module, f"{param_name}_lora_wrapper")

    # Check simple attributes if it's directly wrapped
    if hasattr(experts_module, param_name):
        attr = getattr(experts_module, param_name)
        if hasattr(attr, "lora_A"):  # Is a ParamWrapper
            return attr

    return None


def native_moe_grouped_mm(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Native implementation using grouped_mm with backward fix.

    Uses custom autograd function to avoid PyTorch's grouped_mm backward stride bug.
    """
    return _grouped_mm_with_backward_fix(inputs, weight, offsets)


def _apply_lora_grouped_mm(
    inputs: torch.Tensor,
    lora_B: torch.Tensor,
    lora_A: torch.Tensor,
    offsets: torch.Tensor,
    scaling: float,
    grouped_mm_func=native_moe_grouped_mm,
) -> torch.Tensor:
    """
    Apply LoRA using grouped GEMM: result = ((X @ B) @ A) * scaling

    Args:
        inputs: (total_tokens, in_dim)
        lora_B: (num_experts, in_dim, rank) - First projection
        lora_A: (num_experts, rank, out_dim) - Second projection
        offsets: Grouped GEMM offsets
        scaling: LoRA scaling factor
        grouped_mm_func: Function to use for grouped GEMM (default: native_moe_grouped_mm)
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    # 1. First Matmul (X @ B)
    # lora_B is (E, in_dim, R)
    # Native needs (E, in_dim, R) -> No Transpose
    lora_intermediate = grouped_mm_func(inputs, lora_B.contiguous(), offsets)

    # 2. Second Matmul (result @ A)
    # lora_A is (E, R, out_dim)
    # Native needs (E, R, out_dim) -> No Transpose
    lora_delta = grouped_mm_func(lora_intermediate, lora_A.contiguous(), offsets)

    return lora_delta * scaling

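For a single expert, the two grouped matmuls reduce to an ordinary low-rank product. The following is a minimal pure-Python sketch with list-of-lists matrices standing in for tensors; `matmul` and `lora_delta` are illustrative helpers, not module functions:

```python
def matmul(X, W):
    """Plain list-of-lists matrix multiply, standing in for one grouped GEMM group."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]

def lora_delta(X, B, A, scaling):
    """Low-rank update for one expert: ((X @ B) @ A) * scaling."""
    Y = matmul(matmul(X, B), A)
    return [[v * scaling for v in row] for row in Y]

# Rank-1 update on a single token: X (1x2), B (2x1), A (1x2), scaling 0.5
# lora_delta([[1.0, 2.0]], [[1.0], [1.0]], [[2.0, 0.0]], 0.5) -> [[3.0, 0.0]]
```

Going through the rank-R intermediate first is the point: it costs O(N·(in+out)·R) instead of the O(N·in·out) a merged weight update would.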

def _should_use_separated_lora() -> bool:
    """
    Check if separated LoRA approach should be used (default: True).
    Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead.
    """
    return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1"


# ============================================================================
# Model-specific Weight Preprocessing Hooks
# ============================================================================
# Each model can register its own preprocessing function for weight transposition.
# This allows the generic backend to work with different model weight layouts.

_WEIGHT_PREPROCESSORS = {}


def register_weight_preprocessor(model_type: str, preprocessor_fn):
    """
    Register a weight preprocessor for a specific model type.

    Args:
        model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe")
        preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight
                         proj_type is "gate_up" or "down"
    """
    _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn


def get_weight_preprocessor(model_type: str):
    """Get registered weight preprocessor for model type."""
    return _WEIGHT_PREPROCESSORS.get(model_type)


def preprocess_weight(
    weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
):
    """
    Preprocess weight tensor for grouped_mm compatibility.

    Uses model-specific preprocessor if registered, otherwise uses default logic.

    Args:
        weight: Weight tensor (E, dim1, dim2) or similar
        proj_type: "gate_up" or "down"
        hidden_dim: Hidden dimension for shape inference
        model_type: Optional model type to use specific preprocessor

    Returns:
        Weight tensor in (E, in_dim, out_dim) format for grouped_mm
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    if model_type and model_type in _WEIGHT_PREPROCESSORS:
        return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim)

    # Default preprocessing: check if transposition is needed
    if proj_type == "gate_up":
        # For gate_up, we need (E, hidden_dim, 2*intermediate)
        if weight.shape[1] == hidden_dim:
            return weight
        else:
            return weight.transpose(-2, -1)
    else:  # down
        # For down, we need (E, intermediate, hidden_dim)
        if weight.shape[2] == hidden_dim:
            return weight
        else:
            return weight.transpose(-2, -1)

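The default branch of `preprocess_weight` decides purely from shapes. `needs_transpose` below is a hypothetical shape-only mirror of that decision, not a module function:

```python
def needs_transpose(shape, proj_type, hidden_dim):
    """Decide on shapes alone whether the default preprocessor would transpose."""
    E, d1, d2 = shape
    if proj_type == "gate_up":
        return d1 != hidden_dim  # want (E, hidden_dim, 2*intermediate)
    return d2 != hidden_dim      # "down": want (E, intermediate, hidden_dim)

# Qwen3-MoE stores gate_up as (E, 2*I, H), e.g. (8, 5632, 1024) with H=1024 -> transpose
# Qwen3-VL-MoE stores it as (E, H, 2*I), e.g. (8, 1024, 5632) -> keep as-is
```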
| 616 |
+
# ============================================================================
|
| 617 |
+
# Generic MoE Detection and ParamWrapper Patching
|
| 618 |
+
# ============================================================================
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def _is_moe_experts_module(module) -> bool:
|
| 622 |
+
"""
|
| 623 |
+
Check if module is an MoE experts layer (generic, not model-specific).
|
| 624 |
+
|
| 625 |
+
Detects modules with stacked expert weights as 3D nn.Parameter:
|
| 626 |
+
- gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.)
|
| 627 |
+
- w1/w2/w3 pattern (older MoE models)
|
| 628 |
+
"""
|
| 629 |
+
# This Unsloth Zoo code section is licensed under AGPL3
|
| 630 |
+
|
| 631 |
+
import torch.nn as nn
|
| 632 |
+
|
| 633 |
+
# Check for gate_up_proj pattern
|
| 634 |
+
# After PEFT's nn.utils.parametrize wrapping, accessing gate_up_proj
|
| 635 |
+
# returns torch.Tensor (not nn.Parameter), so we must accept both.
|
| 636 |
+
if hasattr(module, "gate_up_proj"):
|
| 637 |
+
param = module.gate_up_proj
|
| 638 |
+
# 4-bit parameters are packed into 2D tensors (n_params, 1) or similar.
|
| 639 |
+
# Standard MoE weights are 3D (num_experts, in, out).
|
| 640 |
+
if isinstance(param, (nn.Parameter, torch.Tensor)) and param.ndim in (2, 3):
|
| 641 |
+
return True
|
| 642 |
+
|
| 643 |
+
# Check for w1/w2 pattern (separate gate/up projections)
|
| 644 |
+
if hasattr(module, "w1") and hasattr(module, "w2"):
|
| 645 |
+
w1 = module.w1
|
| 646 |
+
if isinstance(w1, (nn.Parameter, torch.Tensor)) and w1.ndim in (2, 3):
|
| 647 |
+
return True
|
| 648 |
+
|
| 649 |
+
return False
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
# Aliases for compatibility with gpt_oss.py
|
| 653 |
+
_get_moe_lora_weights = _extract_lora_from_wrapper
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# Store original ParamWrapper.forward for fallback
|
| 657 |
+
_original_param_wrapper_forward = None
|
| 658 |
+
|
| 659 |
+
|
def _patched_param_wrapper_forward(
    self, x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Patched ParamWrapper.forward for MoE separated LoRA.

    For MoE expert modules:
    - Bypasses PEFT's _activate_lora parametrization context
    - Stores LoRA data by parameter_name for forward_native_grouped_mm to use

    For non-MoE modules:
    - Falls back to the original PEFT forward
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    # CRITICAL: Use self.base_layer for the forward call (the immediate parent),
    # NOT self.get_base_layer(), which recursively traverses to the deepest layer!
    # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts
    immediate_base_layer = self.base_layer

    # For storing LoRA data, we DO need the actual experts module.
    # Use get_base_layer() to find it (recursive traversal is correct here).
    experts_module = self.get_base_layer()

    use_separated = _should_use_separated_lora()
    param_name = getattr(self, "parameter_name", None)

    # Check if this is an MoE experts module that should use separated LoRA
    if (
        use_separated
        and param_name in ("gate_up_proj", "down_proj")
        and _is_moe_experts_module(experts_module)
    ):
        # MoE experts: bypass PEFT's _activate_lora, use separated computation

        # Check adapter state
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return immediate_base_layer(x, *args, **kwargs)

        if self.merged:
            return immediate_base_layer(x, *args, **kwargs)

        # Ensure wrapper.num_experts is set for LoRA weight reshaping
        if not hasattr(self, "num_experts"):
            if hasattr(experts_module, "num_experts"):
                self.num_experts = experts_module.num_experts
            elif hasattr(experts_module, param_name):
                p = getattr(experts_module, param_name)
                if hasattr(p, "shape") and len(p.shape) >= 1:
                    self.num_experts = p.shape[0]

        # Extract LoRA for this specific parameter
        lora_data = _extract_lora_from_wrapper(self)

        if lora_data is not None and param_name:
            # Store LoRA data on the EXPERTS MODULE (not base_layer),
            # e.g. _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
            lora_attr = f"_unsloth_lora_{param_name}"
            setattr(experts_module, lora_attr, lora_data)

        try:
            # Call the IMMEDIATE base_layer to preserve the wrapper chain
            # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts)
            result = immediate_base_layer(x, *args, **kwargs)
        finally:
            # Clean up
            if param_name:
                lora_attr = f"_unsloth_lora_{param_name}"
                if hasattr(experts_module, lora_attr):
                    delattr(experts_module, lora_attr)

        return result

    # Non-MoE: use original PEFT forward with _activate_lora
    return _original_param_wrapper_forward(self, x, *args, **kwargs)

def patch_param_wrapper_for_moe():
    """
    Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE.

    This should be called after PEFT is imported.
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    global _original_param_wrapper_forward

    module = _load_cached_moe_utils_module()
    if module is not None and hasattr(module, "patch_param_wrapper_for_moe"):
        try:
            return module.patch_param_wrapper_for_moe()
        except Exception:
            pass

    try:
        from peft.tuners.lora.layer import ParamWrapper

        # Store original forward
        if _original_param_wrapper_forward is None:
            _original_param_wrapper_forward = ParamWrapper.forward

        # Patch with our version
        ParamWrapper.forward = _patched_param_wrapper_forward

        return True
    except ImportError:
        return False

def forward_native_grouped_mm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Native PyTorch grouped GEMM MoE forward pass.
    Uses torch._grouped_mm, which is significantly faster than a loop and works without Triton dependencies.
    Requires torch._grouped_mm support (verified via runtime check).
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    # Runtime safety check - defense in depth
    if not _check_torch_grouped_mm_supported():
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        raise RuntimeError(
            f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). "
            f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend."
        )

    is_2d_input = hidden_states.dim() == 2
    if is_2d_input:
        sequence_length, hidden_dim = hidden_states.shape
        batch_size = 1
    else:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

    hidden_states = hidden_states.view(-1, hidden_dim)

    # 1. Calculate routing
    flat_top_k = top_k_index.view(-1)
    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()

    # 2. Sort indices to group tokens by expert
    sorted_indices = torch.argsort(flat_top_k, stable=True)
    token_indices = sorted_indices // top_k_index.shape[-1]

    # 3. Permute input.
    # We need to gather inputs; since top_k may expand tokens, token_indices
    # maps each slot back to its original input row.
    permuted_input = hidden_states[token_indices]

    # 4. Prepare grouped MM arguments
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

    # ========================================================================
    # Gate + Up projection with optional separated LoRA (DEFAULT)
    # ========================================================================
    use_separated_lora = _should_use_separated_lora()
    gate_up_lora = None

    # Check for injected LoRA data from patched ParamWrapper (preferred path)
    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
        gate_up_lora = self._unsloth_lora_gate_up_proj[
            :3
        ]  # (first_weight, second_weight, scaling)
    # Fallback: check parameter directly (for older wrapping patterns)
    elif (
        use_separated_lora
        and hasattr(self, "gate_up_proj")
        and _has_lora_adapters(self.gate_up_proj)
    ):
        gate_up_lora = _extract_lora_weights(
            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
        )

    if hasattr(self, "gate_up_proj"):
        # Get base weights (raw, without LoRA)
        gate_up_base = _get_base_weight(self.gate_up_proj)

        # Get model type for preprocessing (if registered)
        model_type = getattr(self, "_unsloth_model_type", None)

        # Handle different weight shapes using the preprocessor.
        # torch._grouped_mm backward requires contiguous weights; preprocessing may return a transposed view.
        w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type)
        # Base forward: X @ W
        mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)

        # Add separated LoRA contribution: + ((X @ first) @ second) * scaling
        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
        if gate_up_lora is not None:
            first_weight, second_weight, scaling = gate_up_lora

            # Cast to input dtype (LoRA weights are float32, input may be bfloat16).
            # Ensure contiguity for grouped_mm alignment requirements.
            first_weight = first_weight.to(permuted_input.dtype).contiguous()
            second_weight = second_weight.to(permuted_input.dtype).contiguous()

            # Step 1: permuted_input @ first_weight
            lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)
            lora_out = lora_out.contiguous()

            # Step 2: result @ second_weight.
            # Handle an unaligned output dimension or other grouped_mm failures.
            try:
                if second_weight.shape[-1] % 8 != 0:
                    pad_size = 8 - (second_weight.shape[-1] % 8)
                    second_weight_padded = F.pad(
                        second_weight, (0, pad_size)
                    ).contiguous()
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight_padded, offsets
                    )
                    lora_delta = lora_delta[:, :-pad_size]
                else:
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight, offsets
                    )
            except RuntimeError:
                # Fallback to a manual loop if grouped_mm fails (e.g. stride alignment)
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end

            # Add scaled LoRA contribution
            mm1_out = mm1_out + lora_delta * scaling

        if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
            num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device)
            bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0)
            mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)

        if "GptOssExperts" in self.__class__.__name__:
            gate = mm1_out[..., ::2]
            up = mm1_out[..., 1::2]
        else:
            gate, up = mm1_out.chunk(2, dim=-1)

    elif hasattr(self, "w1") and hasattr(self, "w3"):
        # Separate w1/w3 weights (older models)
        w1_base = _get_base_weight(self.w1)
        w3_base = _get_base_weight(self.w3)

        w1 = w1_base.transpose(-2, -1)
        w3 = w3_base.transpose(-2, -1)

        gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
        up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets)

        # Add LoRA for w1 and w3 separately if present
        if use_separated_lora:
            if _has_lora_adapters(self.w1):
                w1_lora = _extract_lora_weights(self.w1, experts_module=self)
                if w1_lora is not None:
                    lora_A, lora_B, scaling = w1_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    gate = gate + lora_B_out * scaling

            if _has_lora_adapters(self.w3):
                w3_lora = _extract_lora_weights(self.w3, experts_module=self)
                if w3_lora is not None:
                    lora_A, lora_B, scaling = w3_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    up = up + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.")

    # Activation
    if "GptOssExperts" in self.__class__.__name__:
        # Custom activation from GptOss
        limit = getattr(self, "limit", 7.0)
        alpha = getattr(self, "alpha", 1.702)

        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        inter = (up + 1.0) * glu
    else:
        inter = F.silu(gate) * up

    # ========================================================================
    # Down projection with optional separated LoRA (DEFAULT)
    # ========================================================================
    down_lora = None

    # Check for injected LoRA data from patched ParamWrapper (preferred path)
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[
            :3
        ]  # (first_weight, second_weight, scaling)
    # Fallback: check parameter directly (for older wrapping patterns)
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self)

    if hasattr(self, "down_proj"):
        # Get base weights
        down_base = _get_base_weight(self.down_proj)

        # Get model type for preprocessing (if registered)
        model_type = getattr(self, "_unsloth_model_type", None)

        # Handle different weight shapes using the preprocessor
        w2 = preprocess_weight(down_base, "down", hidden_dim, model_type)

        # Base forward
        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)

        # Add separated LoRA contribution if present.
        # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
        if down_lora is not None:
            first_weight, second_weight, scaling = down_lora

            # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
            first_weight = first_weight.to(inter.dtype).contiguous()
            second_weight = second_weight.to(inter.dtype).contiguous()

            # Step 1: inter @ first_weight
            lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets)
            lora_out = lora_out.contiguous()

            # Step 2: result @ second_weight
            try:
                lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
            except RuntimeError:
                # Fallback to a manual loop
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end

            # Add scaled LoRA contribution
            mm2_out = mm2_out + lora_delta * scaling

        if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
            bias_expanded = self.down_proj_bias.repeat_interleave(
                num_tokens_per_expert.to(self.down_proj_bias.device), dim=0
            ).to(mm2_out.device)
            mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)

    elif hasattr(self, "w2"):
        w2_base = _get_base_weight(self.w2)
        w2 = w2_base.transpose(-2, -1)

        # Base forward
        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)

        # Add LoRA if present
        if use_separated_lora and _has_lora_adapters(self.w2):
            w2_lora = _extract_lora_weights(self.w2, experts_module=self)
            if w2_lora is not None:
                lora_A, lora_B, scaling = w2_lora
                lora_A_t = lora_A.transpose(-2, -1).contiguous()
                lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets)
                lora_B_t = lora_B.transpose(-2, -1).contiguous()
                lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                mm2_out = mm2_out + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'down_proj' or 'w2'.")

    # 5. Apply routing weights and scatter-add (reduce)
    flat_weights = top_k_weights.view(-1)
    permuted_weights = flat_weights[sorted_indices]
    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype))

    if is_2d_input:
        return final_hidden_states

    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)

def forward_triton_grouped_gemm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Grouped GEMM MoE forward pass using Triton kernels.
    Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).
    """
    # This Unsloth Zoo code section is licensed under AGPL3

    # Import grouped GEMM interface
    from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm

    # Import autotune cache
    from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels

    # Helper to check TMA support: in the original this was a cached closure;
    # here _supports_tma() can be used directly.

    # We need somewhere to store model dims and kernel configs across calls;
    # attaching them to self is cleanest: self._unsloth_moe_configs
    if not hasattr(self, "_unsloth_moe_configs"):
        self._unsloth_moe_configs = None

    use_separated_lora = _should_use_separated_lora()

    # Handle 3D inputs (batch_size, seq_len, hidden_dim)
    is_3d = hidden_states.dim() == 3
    if is_3d:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = batch_size * seq_len
        # Also flatten top_k inputs if they are 3D
        if top_k_index.dim() == 3:
            top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
        if top_k_weights.dim() == 3:
            top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
    else:
        num_tokens, hidden_dim = hidden_states.shape

    top_k = top_k_index.shape[1]

    # Cache model dimensions and kernel configs on first call
    if self._unsloth_moe_configs is None:
        intermediate_dim = self.gate_up_proj.shape[1] // 2

        # Autotune first GEMM
        gemm1_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim * 2,
            top_k=top_k,
            dtype=hidden_states.dtype,
        )

        # Autotune second GEMM
        gemm2_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=intermediate_dim,
            intermediate_dim=hidden_dim,  # Output dim for 2nd GEMM is hidden_dim
            top_k=top_k,
            dtype=hidden_states.dtype,
        )

        self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)

        # Clear autotuning memory overhead
        torch.cuda.empty_cache()

    # Unpack cached configs
    intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs

    # Unpack specific kernel configs
    fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
    fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs

    # Compute routing indices for grouped GEMM
    token_counts_by_expert, gather_indices = _get_routing_indices(
        top_k_index, self.num_experts
    )
    offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)

    if self.gate_up_proj.shape[-1] == hidden_dim:
        w1 = self.gate_up_proj
    else:
        w1 = self.gate_up_proj.transpose(-2, -1).contiguous()

    # First grouped GEMM: gate_up projection
    first_gemm_output = grouped_gemm(
        X=hidden_states,
        W=w1,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=True,
        permute_y=False,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_1,
        kernel_config_bwd_dX=bwd_dX_config_1,
        kernel_config_bwd_dW=bwd_dW_config_1,
        is_first_gemm=True,
    )

    # Apply SiLU activation and multiply gate with up
    intermediate = _silu_and_mul(first_gemm_output)

    # Second grouped GEMM: down projection.
    # Prepare LoRA data
    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts)

    if self.down_proj.shape[-1] == intermediate.shape[-1]:
        w2 = self.down_proj
    else:
        w2 = self.down_proj.transpose(-2, -1).contiguous()

    second_gemm_output = grouped_gemm(
        X=intermediate,
        W=w2,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=False,
        permute_y=True,
        autotune=False,  # We use cached configs
        kernel_config_fwd=fwd_config_2,
        kernel_config_bwd_dX=bwd_dX_config_2,
        kernel_config_bwd_dW=bwd_dW_config_2,
        is_first_gemm=False,
    )

    # Add separated LoRA contribution for the down projection
    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora

        # `intermediate` is already permuted from step 1; the offsets are the same.
        first_weight = first_weight.to(intermediate.dtype)
        second_weight = second_weight.to(intermediate.dtype)

        lora_delta = _apply_lora_grouped_mm(
            intermediate,
            first_weight,
            second_weight,
            offsets,
            scaling,
            grouped_mm_func=native_moe_grouped_mm,
        )

        second_gemm_output = second_gemm_output + lora_delta

    # Apply routing weights and sum across top_k experts.
    # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim).
    # Ensure top_k_weights matches dtype (can be float32 from softmax).
    top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
    final_hidden_states = (
        second_gemm_output.view(num_tokens, top_k, hidden_dim)
        * top_k_weights_casted[..., None]
    )
    final_hidden_states = final_hidden_states.sum(dim=1)

    if is_3d:
        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)

    return final_hidden_states

@torch.compiler.disable
def forward_native_moe_loop(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Loop-based MoE forward pass. Loops over experts that have tokens routed to them.
    Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow.
    """
    # This Unsloth Zoo code section is licensed under AGPL3
    final_hidden_states = torch.zeros_like(hidden_states)

    # Create expert mask and find which experts have tokens
    with torch.no_grad():
        expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, n_tokens)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

    # Only loop over experts that actually have tokens routed to them
    for expert_idx_t in expert_hit:
        expert_idx = expert_idx_t.item()

        # Find which tokens are routed to this expert
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

        # Gather only the tokens for this expert
        current_state = hidden_states[token_idx]

        # Compute gate_up projection for this expert only.
        # Handle 'gate_up_proj' or 'w1'/'w3'.
        if hasattr(self, "gate_up_proj"):
            gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
                2, dim=-1
            )
        else:
            gate = F.linear(current_state, self.w1[expert_idx])
            up = F.linear(current_state, self.w3[expert_idx])

        current_hidden_states = self.act_fn(gate) * up

        # Compute down projection for this expert only
        if hasattr(self, "down_proj"):
            current_hidden_states = F.linear(
                current_hidden_states, self.down_proj[expert_idx]
            )
        else:
            current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])

        # Apply routing weights
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )

        # Scatter back to final output
        final_hidden_states.index_add_(
            0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
        )

    return final_hidden_states
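The loop above follows a gather → expert transform → weight → scatter-add pattern. The same control flow can be shown in a stdlib-only sketch with the expert projections replaced by scalar functions (all names here are illustrative, not from the file):

```python
# Stdlib-only sketch of the gather -> expert fn -> weight -> scatter-add
# pattern used by forward_native_moe_loop (experts reduced to scalar fns).
def moe_loop(tokens, top_k_index, top_k_weights, experts):
    out = [0.0] * len(tokens)
    for expert_idx, fn in enumerate(experts):
        for tok, (idxs, wts) in enumerate(zip(top_k_index, top_k_weights)):
            for slot, e in enumerate(idxs):
                if e == expert_idx:  # token routed to this expert at this slot
                    # weight the expert output and scatter-add into the result
                    out[tok] += wts[slot] * fn(tokens[tok])
    return out

experts = [lambda x: x + 1, lambda x: 2 * x]
result = moe_loop([1.0, 3.0], [[0, 1], [1, 0]], [[0.5, 0.5], [0.25, 0.75]], experts)
# token 0: 0.5*(1+1) + 0.5*(2*1) = 2.0 ; token 1: 0.25*(2*3) + 0.75*(3+1) = 4.5
```

The real implementation inverts the inner two loops with `torch.where` so each expert processes its tokens in one batched matmul, but the accumulation into the output is the same `index_add_`-style scatter.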
uv.lock
CHANGED
The diff for this file is too large to render. See raw diff.