Add Training Scripts
Browse files- trainer-kit/.gitignore +55 -0
- trainer-kit/CPT-14b/README.md +189 -0
- trainer-kit/CPT-14b/README_instruct.md +168 -0
- trainer-kit/CPT-14b/commands.md +15 -0
- trainer-kit/CPT-14b/config.yaml +91 -0
- trainer-kit/CPT-14b/dummy_data.jsonl +6 -0
- trainer-kit/CPT-14b/requirements.txt +25 -0
- trainer-kit/CPT-14b/run_cpt.py +772 -0
- trainer-kit/CPT/README.md +189 -0
- trainer-kit/CPT/commands.md +15 -0
- trainer-kit/CPT/config.yaml +82 -0
- trainer-kit/CPT/detailed_parameter_documentation.md +795 -0
- trainer-kit/CPT/dummy_data.jsonl +6 -0
- trainer-kit/CPT/requirements.txt +22 -0
- trainer-kit/CPT/run_cpt.py +708 -0
- trainer-kit/SFT-14b/.DS_Store +0 -0
- trainer-kit/SFT-14b/config_instruct.yaml +144 -0
- trainer-kit/SFT-14b/instruct_data.jsonl +4 -0
- trainer-kit/SFT-14b/requirements.txt +23 -0
- trainer-kit/SFT-14b/run_instruct.py +844 -0
- trainer-kit/SFT/.DS_Store +0 -0
- trainer-kit/SFT/config_instruct.yaml +144 -0
- trainer-kit/SFT/instruct_data.jsonl +4 -0
- trainer-kit/SFT/requirements.txt +23 -0
- trainer-kit/SFT/run_instruct.py +921 -0
- trainer-kit/documentation.md +296 -0
trainer-kit/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
env.bak/
|
| 11 |
+
venv.bak/
|
| 12 |
+
pythonenv*
|
| 13 |
+
.pytest_cache/
|
| 14 |
+
ipynb_checkpoints/
|
| 15 |
+
|
| 16 |
+
# Virtualenv
|
| 17 |
+
.venv
|
| 18 |
+
venv/
|
| 19 |
+
virtualenv/
|
| 20 |
+
env/
|
| 21 |
+
|
| 22 |
+
# IDE
|
| 23 |
+
.vscode/
|
| 24 |
+
.idea/
|
| 25 |
+
*.sublime-workspace
|
| 26 |
+
*.sublime-project
|
| 27 |
+
*.swp
|
| 28 |
+
*.swo
|
| 29 |
+
|
| 30 |
+
# Build
|
| 31 |
+
build/
|
| 32 |
+
dist/
|
| 33 |
+
*.egg-info/
|
| 34 |
+
|
| 35 |
+
# Data and logs
|
| 36 |
+
data/
|
| 37 |
+
logs/
|
| 38 |
+
*.log
|
| 39 |
+
runs/**
|
| 40 |
+
output/**
|
| 41 |
+
|
| 42 |
+
# Jupyter
|
| 43 |
+
.ipynb_checkpoints/
|
| 44 |
+
|
| 45 |
+
# Environment
|
| 46 |
+
.env
|
| 47 |
+
.ENV
|
| 48 |
+
.env.bak
|
| 49 |
+
.venv
|
| 50 |
+
venv
|
| 51 |
+
venv.bak
|
| 52 |
+
|
| 53 |
+
# OS generated files
|
| 54 |
+
.DS_Store
|
| 55 |
+
Thumbs.db
|
trainer-kit/CPT-14b/README.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Trainer‑Kit : Config‑Driven CPT (LoRA / QLoRA) with Packing, Logging, Resume, and Merge
|
| 2 |
+
|
| 3 |
+
Trainer‑Kit is a small, config‑driven training runner for **continued pretraining (CPT)** on causal LMs.
|
| 4 |
+
It supports **LoRA** and **QLoRA**, data **packing** (strict or padding‑masked), **checkpointing + resume**, **JSONL logging**, periodic **eval with perplexity**, and an optional **merge** step to export a final merged model.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## What we built
|
| 9 |
+
|
| 10 |
+
### ✅ Core goals implemented
|
| 11 |
+
|
| 12 |
+
* **CPT training loop** controlled entirely via a **YAML config**
|
| 13 |
+
* **Local model support** (load from filesystem) and optional **HF download** (if `repo_id` is a hub id)
|
| 14 |
+
* **JSONL datasets** for train (+ optional eval split)
|
| 15 |
+
* **CPT‑style token stream packing** into fixed‑length blocks
|
| 16 |
+
* **Two packing modes**
|
| 17 |
+
|
| 18 |
+
* `drop`: strict CPT, drop remainder tokens (preferred for real CPT)
|
| 19 |
+
* `pad`: pad the remainder to `block_size` and **mask loss** on padding (useful for small datasets / debugging)
|
| 20 |
+
* **Checkpointing + resume**
|
| 21 |
+
|
| 22 |
+
* `resume_from_checkpoint: "auto"` resumes from the latest checkpoint under `run_dir/checkpoints`
|
| 23 |
+
* **JSONL logs** written locally
|
| 24 |
+
|
| 25 |
+
* training logs: `run_dir/logs/train.jsonl`
|
| 26 |
+
* eval logs: `run_dir/logs/eval.jsonl`
|
| 27 |
+
* **Evaluation**
|
| 28 |
+
|
| 29 |
+
* logs `eval_loss` and computed `perplexity = exp(eval_loss)` (with safe overflow guard)
|
| 30 |
+
* **Adapter output**
|
| 31 |
+
|
| 32 |
+
* saves the final/best adapter to `run_dir/best_adapter`
|
| 33 |
+
* **Merge workflow**
|
| 34 |
+
|
| 35 |
+
* `--merge-only` merges an existing adapter later
|
| 36 |
+
* merge is done **on CPU** to avoid GPU OOM
|
| 37 |
+
* merged model is stored under the configured merge output directory (relative to `run_dir` if a relative path)
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## Repository layout (outputs)
|
| 42 |
+
|
| 43 |
+
A run produces the following structure under `run.run_dir`:
|
| 44 |
+
|
| 45 |
+
```
|
| 46 |
+
runs/<run_name>/
|
| 47 |
+
├─ checkpoints/ # trainer checkpoints (for resume)
|
| 48 |
+
├─ best_adapter/ # saved LoRA adapter
|
| 49 |
+
├─ logs/
|
| 50 |
+
│ ├─ train.jsonl # step-wise training logs
|
| 51 |
+
│ └─ eval.jsonl # eval logs (eval_loss + perplexity)
|
| 52 |
+
├─ eval_final.json # final eval metrics summary (if eval is enabled)
|
| 53 |
+
└─ config_resolved.yaml # exact config used for the run
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
If merge is used, the merged model is written to:
|
| 57 |
+
|
| 58 |
+
* `run_dir/<merge.output_dir>` if `merge.output_dir` is relative (e.g. `./merged_model`)
|
| 59 |
+
* or the absolute path if it is absolute.
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## Supported training modes
|
| 64 |
+
|
| 65 |
+
### 1) LoRA vs QLoRA (same script)
|
| 66 |
+
|
| 67 |
+
* **QLoRA** happens when `model.use_4bit: true`
|
| 68 |
+
|
| 69 |
+
* base weights are loaded in 4‑bit using bitsandbytes
|
| 70 |
+
* training updates only LoRA parameters
|
| 71 |
+
* **LoRA** happens when `model.use_4bit: false`
|
| 72 |
+
|
| 73 |
+
* base weights are loaded in fp16/bf16 (as configured)
|
| 74 |
+
* training updates only LoRA parameters
|
| 75 |
+
|
| 76 |
+
No “full finetune” mode is enabled by default in this runner.
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## Data pipeline (CPT behavior)
|
| 81 |
+
|
| 82 |
+
### Input format
|
| 83 |
+
|
| 84 |
+
* JSONL file where each line contains a text field (default `"text"`).
|
| 85 |
+
* Example:
|
| 86 |
+
|
| 87 |
+
* `{"text": "some training text..."}`
|
| 88 |
+
|
| 89 |
+
### Packing (token stream → fixed blocks)
|
| 90 |
+
|
| 91 |
+
* Each sample is tokenized without truncation.
|
| 92 |
+
* An **EOS token is appended** per document to preserve boundaries.
|
| 93 |
+
* Token lists are concatenated and converted into **fixed‑length blocks** of `data.block_size`.
|
| 94 |
+
|
| 95 |
+
Two modes:
|
| 96 |
+
|
| 97 |
+
* **`drop` (strict CPT):** remainder tokens that don’t fill a full block are discarded.
|
| 98 |
+
* **`pad` (debug/small data):** remainder is padded to block_size:
|
| 99 |
+
|
| 100 |
+
* `attention_mask = 0` for padded positions
|
| 101 |
+
* `labels = -100` for padded positions (loss masking)
|
| 102 |
+
|
| 103 |
+
This is what allowed training to proceed even with tiny dummy datasets at `block_size=1024`.
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
## Logging
|
| 108 |
+
|
| 109 |
+
Trainer‑Kit writes **machine‑readable logs** in JSONL.
|
| 110 |
+
|
| 111 |
+
### Training logs (`logs/train.jsonl`)
|
| 112 |
+
|
| 113 |
+
Includes entries with:
|
| 114 |
+
|
| 115 |
+
* `step`
|
| 116 |
+
* `loss`
|
| 117 |
+
* `grad_norm`
|
| 118 |
+
* `learning_rate`
|
| 119 |
+
* `progress_pct` (step progress when `max_steps` is active)
|
| 120 |
+
* ETA estimation
|
| 121 |
+
|
| 122 |
+
### Eval logs (`logs/eval.jsonl`)
|
| 123 |
+
|
| 124 |
+
Includes:
|
| 125 |
+
|
| 126 |
+
* `eval_loss`
|
| 127 |
+
* `perplexity`
|
| 128 |
+
|
| 129 |
+
Notes:
|
| 130 |
+
|
| 131 |
+
* When using `max_steps`, the Trainer’s internal `epoch` counter can grow unexpectedly on tiny datasets (because steps/epoch becomes ~1).
|
| 132 |
+
**Use `progress_pct` as the reliable indicator** for step‑based runs.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## Checkpointing and resume
|
| 137 |
+
|
| 138 |
+
The trainer saves checkpoints under:
|
| 139 |
+
|
| 140 |
+
* `run_dir/checkpoints/`
|
| 141 |
+
|
| 142 |
+
Resume options:
|
| 143 |
+
|
| 144 |
+
* `resume_from_checkpoint: "auto"` → picks the latest checkpoint automatically
|
| 145 |
+
* `resume_from_checkpoint: "/path/to/checkpoint"` → resumes from a specific checkpoint
|
| 146 |
+
* `resume_from_checkpoint: null` → fresh run
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Merging adapters into a final model
|
| 151 |
+
|
| 152 |
+
Trainer‑Kit supports exporting a merged model:
|
| 153 |
+
|
| 154 |
+
### Merge after training
|
| 155 |
+
|
| 156 |
+
* Enable merge in config (`merge.enabled: true`)
|
| 157 |
+
* The script will:
|
| 158 |
+
|
| 159 |
+
1. save the adapter
|
| 160 |
+
2. free GPU memory
|
| 161 |
+
3. reload base model on **CPU**
|
| 162 |
+
4. load adapter
|
| 163 |
+
5. `merge_and_unload()`
|
| 164 |
+
6. save final merged model
|
| 165 |
+
|
| 166 |
+
### Merge later
|
| 167 |
+
|
| 168 |
+
Run:
|
| 169 |
+
|
| 170 |
+
```
|
| 171 |
+
python run_cpt.py --config config.yaml --merge-only
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
This skips training and merges `run_dir/best_adapter` into the base model.
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## How to run
|
| 179 |
+
|
| 180 |
+
### Train
|
| 181 |
+
|
| 182 |
+
```
|
| 183 |
+
python run_cpt.py --config config.yaml
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### Merge only
|
| 187 |
+
|
| 188 |
+
```
|
| 189 |
+
python run_cpt.py --config config.yaml --merge-only
|
trainer-kit/CPT-14b/README_instruct.md
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Instruction Fine-Tuning Script
|
| 2 |
+
|
| 3 |
+
This script (`run_instruct.py`) is designed for fine-tuning language models on instruction-following tasks. It's based on the original CPT script but adapted specifically for instruction input/output pairs.
|
| 4 |
+
|
| 5 |
+
## Key Differences from CPT
|
| 6 |
+
|
| 7 |
+
1. **Data Format**: Handles structured instruction data with separate fields for instruction, input, and output
|
| 8 |
+
2. **Formatting Options**: Supports multiple instruction formats (ChatML, Alpaca, custom templates)
|
| 9 |
+
3. **No Text Packing**: Each example is treated as a complete instruction-response pair
|
| 10 |
+
4. **Proper Loss Masking**: Loss is only computed on the response/output portion, not on the instruction and input
|
| 11 |
+
5. **Automatic Label Creation**: Labels are automatically created with -100 masking for instruction tokens
|
| 12 |
+
|
| 13 |
+
## Supported Data Formats
|
| 14 |
+
|
| 15 |
+
### JSONL Structure
|
| 16 |
+
Each line should be a JSON object with the following fields:
|
| 17 |
+
```json
|
| 18 |
+
{
|
| 19 |
+
"instruction": "Your instruction here",
|
| 20 |
+
"input": "Optional input context (can be empty string)",
|
| 21 |
+
"output": "Expected response"
|
| 22 |
+
}
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Formatting Options
|
| 26 |
+
|
| 27 |
+
#### 1. ChatML Format (Default)
|
| 28 |
+
Uses the model's chat template with system/user/assistant roles:
|
| 29 |
+
```yaml
|
| 30 |
+
data:
|
| 31 |
+
format_type: "chatml"
|
| 32 |
+
system_prompt: "You are a helpful assistant."
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
#### 2. Alpaca Format
|
| 36 |
+
Uses the classic Alpaca instruction format:
|
| 37 |
+
```yaml
|
| 38 |
+
data:
|
| 39 |
+
format_type: "alpaca"
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
#### 3. Custom Format
|
| 43 |
+
Define your own template:
|
| 44 |
+
```yaml
|
| 45 |
+
data:
|
| 46 |
+
format_type: "custom"
|
| 47 |
+
custom_template: "Instruction: {instruction}\nInput: {input}\nOutput: {output}"
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Configuration
|
| 51 |
+
|
| 52 |
+
Key configuration options in `config_instruct.yaml`:
|
| 53 |
+
|
| 54 |
+
### Data Configuration
|
| 55 |
+
```yaml
|
| 56 |
+
data:
|
| 57 |
+
train_jsonl: "path/to/your/train.jsonl"
|
| 58 |
+
eval_jsonl: "path/to/your/eval.jsonl" # optional
|
| 59 |
+
eval_split_ratio: 0.1 # if no eval file provided
|
| 60 |
+
|
| 61 |
+
# Field names in your data
|
| 62 |
+
instruction_field: "instruction"
|
| 63 |
+
input_field: "input"
|
| 64 |
+
output_field: "output"
|
| 65 |
+
|
| 66 |
+
# Formatting
|
| 67 |
+
format_type: "chatml" # "chatml" | "alpaca" | "custom"
|
| 68 |
+
system_prompt: "You are a helpful assistant."
|
| 69 |
+
|
| 70 |
+
# Tokenization
|
| 71 |
+
max_length: 2048
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### Training Configuration
|
| 75 |
+
```yaml
|
| 76 |
+
train:
|
| 77 |
+
max_steps: 100
|
| 78 |
+
num_train_epochs: 3
|
| 79 |
+
per_device_train_batch_size: 1
|
| 80 |
+
gradient_accumulation_steps: 16
|
| 81 |
+
learning_rate: 5e-5
|
| 82 |
+
# ... other training parameters
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Usage
|
| 86 |
+
|
| 87 |
+
### Basic Usage
|
| 88 |
+
```bash
|
| 89 |
+
python run_instruct.py --config config_instruct.yaml
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Merge Only (after training)
|
| 93 |
+
```bash
|
| 94 |
+
python run_instruct.py --config config_instruct.yaml --merge-only
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
## Example Data Format
|
| 98 |
+
|
| 99 |
+
See `instruct_data.jsonl` for examples of the expected data format. Here are a few examples:
|
| 100 |
+
|
| 101 |
+
```json
|
| 102 |
+
{"instruction": "What is the capital of France?", "input": "", "output": "The capital of France is Paris."}
|
| 103 |
+
|
| 104 |
+
{"instruction": "Translate the following English text to French.", "input": "Hello, how are you today?", "output": "Bonjour, comment allez-vous aujourd'hui?"}
|
| 105 |
+
|
| 106 |
+
{"instruction": "Write a Python function that calculates factorial.", "input": "", "output": "def factorial(n):\n if n < 0:\n raise ValueError(...)"}
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Key Features
|
| 110 |
+
|
| 111 |
+
1. **Multiple Format Support**: ChatML, Alpaca, and custom templates
|
| 112 |
+
2. **Flexible Field Mapping**: Configure custom field names for your data
|
| 113 |
+
3. **Proper Loss Masking**: Only computes loss on the response portion
|
| 114 |
+
4. **PEFT/LoRA Support**: Efficient fine-tuning with LoRA
|
| 115 |
+
5. **Evaluation Support**: Automatic evaluation split or separate eval file
|
| 116 |
+
6. **Checkpointing**: Resume training from checkpoints
|
| 117 |
+
7. **Model Merging**: Merge trained adapters with base model
|
| 118 |
+
|
| 119 |
+
## Best Practices
|
| 120 |
+
|
| 121 |
+
1. **Data Quality**: Ensure your instruction-response pairs are high-quality and consistent
|
| 122 |
+
2. **Format Consistency**: Use the same format for training and inference
|
| 123 |
+
3. **System Prompts**: Choose appropriate system prompts for your use case
|
| 124 |
+
4. **Token Length**: Set appropriate `max_length` based on your model and data
|
| 125 |
+
5. **Batch Size**: Adjust batch size and gradient accumulation based on your GPU memory
|
| 126 |
+
|
| 127 |
+
## Troubleshooting
|
| 128 |
+
|
| 129 |
+
### Common Issues
|
| 130 |
+
|
| 131 |
+
1. **CUDA Out of Memory**: Reduce batch size or enable 4-bit quantization
|
| 132 |
+
2. **Slow Training**: Increase `gradient_accumulation_steps` or reduce `max_length`
|
| 133 |
+
3. **Poor Quality**: Check data format consistency and quality
|
| 134 |
+
4. **Tokenizer Issues**: Ensure your model has proper chat template support
|
| 135 |
+
|
| 136 |
+
### Debug Mode
|
| 137 |
+
Add logging to see formatted examples:
|
| 138 |
+
```python
|
| 139 |
+
# In format_instruction function, add:
|
| 140 |
+
print(f"Formatted: {formatted_text}")
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## File Structure
|
| 144 |
+
|
| 145 |
+
```
|
| 146 |
+
CPT/
|
| 147 |
+
├── run_instruct.py # Main instruction fine-tuning script
|
| 148 |
+
├── config_instruct.yaml # Configuration file
|
| 149 |
+
├── instruct_data.jsonl # Example instruction data
|
| 150 |
+
├── README_instruct.md # This documentation
|
| 151 |
+
└── runs/ # Training outputs
|
| 152 |
+
└── instruct_run_v1/
|
| 153 |
+
├── logs/
|
| 154 |
+
├── checkpoints/
|
| 155 |
+
├── best_adapter/
|
| 156 |
+
└── final_model/
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## Migration from CPT
|
| 160 |
+
|
| 161 |
+
To migrate from the original CPT script:
|
| 162 |
+
|
| 163 |
+
1. Convert your text data to instruction format
|
| 164 |
+
2. Update your configuration file
|
| 165 |
+
3. Choose appropriate formatting options
|
| 166 |
+
4. Adjust training parameters (instruction fine-tuning typically needs fewer steps)
|
| 167 |
+
|
| 168 |
+
The script maintains the same CLI interface and most configuration options for easy migration.
|
trainer-kit/CPT-14b/commands.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Commands
|
| 2 |
+
|
| 3 |
+
Train (no merge):
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
python run_cpt.py --config config.yaml
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
Merge later:
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
python run_cpt.py --config config.yaml --merge-only
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
---
|
trainer-kit/CPT-14b/config.yaml
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
run_dir: "./runs/cpt_run_14b"
|
| 3 |
+
seed: 42
|
| 4 |
+
|
| 5 |
+
# WandB integration for experiment tracking
|
| 6 |
+
wandb:
|
| 7 |
+
enabled: true # Set to true to enable wandb logging
|
| 8 |
+
project: "cpt-training" # WandB project name
|
| 9 |
+
entity: null # WandB entity/team (optional)
|
| 10 |
+
name: null # Run name (optional, will auto-generate if null)
|
| 11 |
+
tags: ["cpt-lora","sft-14b"] # List of tags for the run (e.g., ["lora", "qlora", "experiment-1"])
|
| 12 |
+
notes: null # Run description/notes (optional)
|
| 13 |
+
|
| 14 |
+
model:
|
| 15 |
+
# Local model path (no download)
|
| 16 |
+
repo_id: "/workspace/Models/Qwen2.5-Coder-14B"
|
| 17 |
+
revision: null
|
| 18 |
+
|
| 19 |
+
# Used only when repo_id is a HF repo (not a local path)
|
| 20 |
+
base_local_dir: "base_model"
|
| 21 |
+
|
| 22 |
+
trust_remote_code: true
|
| 23 |
+
tokenizer_use_fast: true
|
| 24 |
+
device_map: "auto"
|
| 25 |
+
|
| 26 |
+
torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
|
| 27 |
+
|
| 28 |
+
# QLoRA
|
| 29 |
+
use_4bit: false
|
| 30 |
+
bnb_4bit_quant_type: "nf4"
|
| 31 |
+
bnb_4bit_use_double_quant: false
|
| 32 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 33 |
+
|
| 34 |
+
# optional: "flash_attention_2" | "sdpa" | null
|
| 35 |
+
attn_implementation: null
|
| 36 |
+
|
| 37 |
+
data:
|
| 38 |
+
train_jsonl: "all_data_with_descriptions.jsonl"
|
| 39 |
+
eval_jsonl: null
|
| 40 |
+
eval_split_ratio: 0.1
|
| 41 |
+
text_field: "text"
|
| 42 |
+
block_size: 4096
|
| 43 |
+
shuffle: true
|
| 44 |
+
num_proc: 4
|
| 45 |
+
|
| 46 |
+
# ✅ NEW: packing behavior
|
| 47 |
+
# "drop" = strict CPT (drop remainder)
|
| 48 |
+
# "pad" = pad remainder to block_size + loss mask (-100) + attention_mask=0
|
| 49 |
+
pack_mode: "pad"
|
| 50 |
+
|
| 51 |
+
peft:
|
| 52 |
+
enabled: true
|
| 53 |
+
r: 32
|
| 54 |
+
lora_alpha: 64
|
| 55 |
+
lora_dropout: 0.05
|
| 56 |
+
bias: "none"
|
| 57 |
+
target_modules: "auto"
|
| 58 |
+
|
| 59 |
+
train:
|
| 60 |
+
# max_steps: 1000
|
| 61 |
+
num_train_epochs: 2
|
| 62 |
+
|
| 63 |
+
per_device_train_batch_size: 1
|
| 64 |
+
per_device_eval_batch_size: 1
|
| 65 |
+
gradient_accumulation_steps: 16
|
| 66 |
+
|
| 67 |
+
learning_rate: 2e-5
|
| 68 |
+
weight_decay: 0.0
|
| 69 |
+
warmup_ratio: 0.1
|
| 70 |
+
lr_scheduler_type: "cosine"
|
| 71 |
+
|
| 72 |
+
optim: "paged_adamw_8bit"
|
| 73 |
+
max_grad_norm: 1.0
|
| 74 |
+
gradient_checkpointing: true
|
| 75 |
+
|
| 76 |
+
logging_steps: 1
|
| 77 |
+
save_strategy: "steps"
|
| 78 |
+
save_steps: 100
|
| 79 |
+
save_total_limit: 7
|
| 80 |
+
|
| 81 |
+
evaluation_strategy: "steps"
|
| 82 |
+
eval_steps: 50
|
| 83 |
+
load_best_model_at_end: true
|
| 84 |
+
|
| 85 |
+
resume_from_checkpoint: "auto"
|
| 86 |
+
|
| 87 |
+
merge:
|
| 88 |
+
enabled: true
|
| 89 |
+
merged_dtype: "float16"
|
| 90 |
+
max_shard_size: "2GB"
|
| 91 |
+
output_dir: "./merged_14b_cpt_lora"
|
trainer-kit/CPT-14b/dummy_data.jsonl
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "This is a test sentence for the dummy dataset."}
|
| 2 |
+
{"text": "Another sentence to check if training works."}
|
| 3 |
+
{"text": "We need enough data to form a batch."}
|
| 4 |
+
{"text": "FSDP and LoRA are cool technologies."}
|
| 5 |
+
{"text": "Fine-tuning LLMs is fun and useful."}
|
| 6 |
+
{"text": "This is the end of the dummy dataset."}
|
trainer-kit/CPT-14b/requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
transformers>=4.41.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
accelerate>=0.30.0
|
| 6 |
+
|
| 7 |
+
# PEFT / QLoRA
|
| 8 |
+
peft>=0.11.1
|
| 9 |
+
bitsandbytes>=0.43.1
|
| 10 |
+
|
| 11 |
+
# Hugging Face Hub (local + download support)
|
| 12 |
+
huggingface_hub>=0.23.0
|
| 13 |
+
|
| 14 |
+
# Config + utilities
|
| 15 |
+
pyyaml>=6.0
|
| 16 |
+
tqdm>=4.66.0
|
| 17 |
+
|
| 18 |
+
# Optional but recommended (tokenizers speed)
|
| 19 |
+
tokenizers>=0.15.0
|
| 20 |
+
safetensors>=0.4.2
|
| 21 |
+
# Optional (for eval)
|
| 22 |
+
rouge-score>=0.1.2
|
| 23 |
+
|
| 24 |
+
# Experiment tracking
|
| 25 |
+
wandb>=0.16.0
|
trainer-kit/CPT-14b/run_cpt.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect # Added for Transformers version compatibility
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
from datasets import load_dataset, DatasetDict
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from transformers import (
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
AutoTokenizer,
|
| 16 |
+
PreTrainedTokenizerFast,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
Trainer,
|
| 19 |
+
TrainerCallback,
|
| 20 |
+
default_data_collator,
|
| 21 |
+
set_seed,
|
| 22 |
+
)
|
| 23 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 24 |
+
from peft import (
|
| 25 |
+
LoraConfig,
|
| 26 |
+
get_peft_model,
|
| 27 |
+
prepare_model_for_kbit_training,
|
| 28 |
+
PeftModel,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from transformers import BitsAndBytesConfig
|
| 33 |
+
except ImportError: # older transformers
|
| 34 |
+
BitsAndBytesConfig = None
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
import wandb
|
| 38 |
+
WANDB_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
WANDB_AVAILABLE = False
|
| 41 |
+
wandb = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# --------------------------
|
| 45 |
+
# Helpers
|
| 46 |
+
# --------------------------
|
| 47 |
+
|
| 48 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 49 |
+
s = (s or "").lower()
|
| 50 |
+
if s in ("float16", "fp16"):
|
| 51 |
+
return torch.float16
|
| 52 |
+
if s in ("bfloat16", "bf16"):
|
| 53 |
+
return torch.bfloat16
|
| 54 |
+
if s in ("float32", "fp32"):
|
| 55 |
+
return torch.float32
|
| 56 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 57 |
+
|
| 58 |
+
def _now_iso() -> str:
|
| 59 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 60 |
+
|
| 61 |
+
def _safe_exp(x: float) -> float:
|
| 62 |
+
x = min(float(x), 50.0)
|
| 63 |
+
return float(math.exp(x))
|
| 64 |
+
|
| 65 |
+
def _ensure_dir(p: Path) -> Path:
|
| 66 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 67 |
+
return p
|
| 68 |
+
|
| 69 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 70 |
+
if not p.exists() or not p.is_dir():
|
| 71 |
+
return False
|
| 72 |
+
if (p / "config.json").exists():
|
| 73 |
+
return True
|
| 74 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 75 |
+
return True
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
def _detect_text_field(example: Dict[str, Any]) -> Optional[str]:
|
| 79 |
+
for k, v in example.items():
|
| 80 |
+
if isinstance(v, str) and v.strip():
|
| 81 |
+
return k
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
def _load_tokenizer(base_dir: Path, use_fast: bool, trust_remote_code: bool):
|
| 85 |
+
try:
|
| 86 |
+
return AutoTokenizer.from_pretrained(
|
| 87 |
+
str(base_dir),
|
| 88 |
+
use_fast=use_fast,
|
| 89 |
+
trust_remote_code=trust_remote_code,
|
| 90 |
+
)
|
| 91 |
+
except ValueError as e:
|
| 92 |
+
if "TokenizersBackend" not in str(e):
|
| 93 |
+
raise
|
| 94 |
+
tok_file = base_dir / "tokenizer.json"
|
| 95 |
+
tok_cfg_path = base_dir / "tokenizer_config.json"
|
| 96 |
+
if not tok_file.exists():
|
| 97 |
+
raise
|
| 98 |
+
|
| 99 |
+
tok_kwargs: Dict[str, Any] = {}
|
| 100 |
+
if tok_cfg_path.exists():
|
| 101 |
+
with tok_cfg_path.open("r", encoding="utf-8") as f:
|
| 102 |
+
tok_cfg = json.load(f)
|
| 103 |
+
for key in ("bos_token", "eos_token", "pad_token", "unk_token", "model_max_length"):
|
| 104 |
+
if tok_cfg.get(key) is not None:
|
| 105 |
+
tok_kwargs[key] = tok_cfg[key]
|
| 106 |
+
extra = tok_cfg.get("additional_special_tokens") or tok_cfg.get("extra_special_tokens")
|
| 107 |
+
if extra:
|
| 108 |
+
tok_kwargs["additional_special_tokens"] = extra
|
| 109 |
+
|
| 110 |
+
return PreTrainedTokenizerFast(tokenizer_file=str(tok_file), **tok_kwargs)
|
| 111 |
+
|
| 112 |
+
def _infer_target_modules(model) -> List[str]:
|
| 113 |
+
names = set()
|
| 114 |
+
for n, _ in model.named_modules():
|
| 115 |
+
names.add(n.split(".")[-1])
|
| 116 |
+
|
| 117 |
+
for group in [
|
| 118 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 119 |
+
["Wqkv", "out_proj"],
|
| 120 |
+
["query_key_value", "dense"],
|
| 121 |
+
["c_attn", "c_proj"],
|
| 122 |
+
]:
|
| 123 |
+
if all(x in names for x in group):
|
| 124 |
+
return group
|
| 125 |
+
|
| 126 |
+
fallback = [x for x in ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense"] if x in names]
|
| 127 |
+
if fallback:
|
| 128 |
+
return fallback
|
| 129 |
+
|
| 130 |
+
raise ValueError("Could not auto-infer target_modules. Set peft.target_modules explicitly.")
|
| 131 |
+
|
| 132 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 133 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# --------------------------
|
| 137 |
+
# Wandb Integration
|
| 138 |
+
# --------------------------
|
| 139 |
+
|
| 140 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 141 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 142 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 143 |
+
|
| 144 |
+
if not wandb_cfg.get("enabled", False):
|
| 145 |
+
print("Wandb logging disabled")
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
if not WANDB_AVAILABLE:
|
| 149 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
# Extract wandb configuration
|
| 153 |
+
project = wandb_cfg.get("project", "cpt-training")
|
| 154 |
+
entity = wandb_cfg.get("entity", None)
|
| 155 |
+
name = wandb_cfg.get("name", None)
|
| 156 |
+
tags = wandb_cfg.get("tags", [])
|
| 157 |
+
notes = wandb_cfg.get("notes", None)
|
| 158 |
+
|
| 159 |
+
# Initialize wandb
|
| 160 |
+
try:
|
| 161 |
+
wandb.init(
|
| 162 |
+
project=project,
|
| 163 |
+
entity=entity,
|
| 164 |
+
name=name,
|
| 165 |
+
tags=tags,
|
| 166 |
+
notes=notes,
|
| 167 |
+
dir=str(run_dir),
|
| 168 |
+
config={
|
| 169 |
+
"model": cfg.get("model", {}),
|
| 170 |
+
"data": cfg.get("data", {}),
|
| 171 |
+
"peft": cfg.get("peft", {}),
|
| 172 |
+
"train": cfg.get("train", {}),
|
| 173 |
+
"run_dir": str(run_dir),
|
| 174 |
+
}
|
| 175 |
+
)
|
| 176 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 177 |
+
return wandb
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def finish_wandb():
|
| 184 |
+
"""Finish Wandb run if active."""
|
| 185 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 186 |
+
wandb.finish()
|
| 187 |
+
print("Wandb run finished")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# --------------------------
|
| 191 |
+
# JSONL Logger Callback
|
| 192 |
+
# --------------------------
|
| 193 |
+
|
| 194 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 195 |
+
def __init__(self, run_dir: Path):
|
| 196 |
+
self.run_dir = run_dir
|
| 197 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 198 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 199 |
+
self.start_time = None
|
| 200 |
+
|
| 201 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 202 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 203 |
+
return None
|
| 204 |
+
elapsed = time.time() - self.start_time
|
| 205 |
+
sec_per_step = elapsed / global_step
|
| 206 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 207 |
+
h = int(remaining // 3600)
|
| 208 |
+
m = int((remaining % 3600) // 60)
|
| 209 |
+
s = int(remaining % 60)
|
| 210 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 211 |
+
|
| 212 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 213 |
+
self.start_time = time.time()
|
| 214 |
+
|
| 215 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 216 |
+
if not logs:
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 220 |
+
progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 221 |
+
epoch_pct = None
|
| 222 |
+
if state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0:
|
| 223 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 224 |
+
|
| 225 |
+
payload = {
|
| 226 |
+
"ts": _now_iso(),
|
| 227 |
+
"event": "train_log",
|
| 228 |
+
"step": int(state.global_step),
|
| 229 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 230 |
+
"progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
|
| 231 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 232 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 233 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 234 |
+
**logs,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 238 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 239 |
+
|
| 240 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 241 |
+
if not metrics:
|
| 242 |
+
return
|
| 243 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 244 |
+
ppl = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 245 |
+
|
| 246 |
+
payload = {
|
| 247 |
+
"ts": _now_iso(),
|
| 248 |
+
"event": "eval",
|
| 249 |
+
"step": int(state.global_step),
|
| 250 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 251 |
+
**metrics,
|
| 252 |
+
"perplexity": ppl,
|
| 253 |
+
}
|
| 254 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 255 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# --------------------------
|
| 259 |
+
# Data Pipeline (EOS + Packing)
|
| 260 |
+
# --------------------------
|
| 261 |
+
|
| 262 |
+
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 263 |
+
data_cfg = cfg["data"]
|
| 264 |
+
train_path = data_cfg["train_jsonl"]
|
| 265 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 266 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 267 |
+
text_field = data_cfg.get("text_field", "text")
|
| 268 |
+
block_size = int(data_cfg.get("block_size", 2048))
|
| 269 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 270 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 271 |
+
|
| 272 |
+
pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
|
| 273 |
+
if pack_mode not in ("drop", "pad"):
|
| 274 |
+
raise ValueError(f"data.pack_mode must be 'drop' or 'pad', got: {pack_mode}")
|
| 275 |
+
|
| 276 |
+
eos_id = tokenizer.eos_token_id
|
| 277 |
+
if eos_id is None:
|
| 278 |
+
raise ValueError("Tokenizer has no eos_token_id; CPT packing needs an EOS delimiter.")
|
| 279 |
+
|
| 280 |
+
if tokenizer.pad_token_id is None:
|
| 281 |
+
# safe default for many causal LMs
|
| 282 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 283 |
+
pad_id = tokenizer.pad_token_id
|
| 284 |
+
|
| 285 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 286 |
+
|
| 287 |
+
if eval_path:
|
| 288 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 289 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 290 |
+
else:
|
| 291 |
+
if 0.0 < split_ratio < 1.0:
|
| 292 |
+
split = ds["train"].train_test_split(test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)))
|
| 293 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 294 |
+
else:
|
| 295 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 296 |
+
|
| 297 |
+
if text_field not in dsd["train"].column_names:
|
| 298 |
+
auto_field = _detect_text_field(dsd["train"][0])
|
| 299 |
+
if not auto_field:
|
| 300 |
+
raise ValueError(f"Could not find text field. Columns: {dsd['train'].column_names}")
|
| 301 |
+
text_field = auto_field
|
| 302 |
+
|
| 303 |
+
def tokenize_fn(examples):
|
| 304 |
+
out = tokenizer(
|
| 305 |
+
examples[text_field],
|
| 306 |
+
add_special_tokens=False,
|
| 307 |
+
truncation=False,
|
| 308 |
+
padding=False,
|
| 309 |
+
)
|
| 310 |
+
if "token_type_ids" in out:
|
| 311 |
+
del out["token_type_ids"]
|
| 312 |
+
# Add EOS between docs
|
| 313 |
+
out["input_ids"] = [ids + [eos_id] for ids in out["input_ids"]]
|
| 314 |
+
out["attention_mask"] = [m + [1] for m in out["attention_mask"]]
|
| 315 |
+
return out
|
| 316 |
+
|
| 317 |
+
tokenized_train = dsd["train"].map(
|
| 318 |
+
tokenize_fn,
|
| 319 |
+
batched=True,
|
| 320 |
+
num_proc=num_proc,
|
| 321 |
+
remove_columns=dsd["train"].column_names,
|
| 322 |
+
desc="Tokenizing train",
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
tokenized_eval = None
|
| 326 |
+
if dsd["eval"] is not None:
|
| 327 |
+
tokenized_eval = dsd["eval"].map(
|
| 328 |
+
tokenize_fn,
|
| 329 |
+
batched=True,
|
| 330 |
+
num_proc=num_proc,
|
| 331 |
+
remove_columns=dsd["eval"].column_names,
|
| 332 |
+
desc="Tokenizing eval",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def group_texts(examples):
|
| 336 |
+
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
| 337 |
+
total_length = len(concatenated["input_ids"])
|
| 338 |
+
|
| 339 |
+
if total_length == 0:
|
| 340 |
+
return {"input_ids": [], "attention_mask": [], "labels": []}
|
| 341 |
+
|
| 342 |
+
full_len = (total_length // block_size) * block_size
|
| 343 |
+
blocks_input, blocks_attn, blocks_labels = [], [], []
|
| 344 |
+
|
| 345 |
+
# full blocks
|
| 346 |
+
for i in range(0, full_len, block_size):
|
| 347 |
+
chunk = concatenated["input_ids"][i:i + block_size]
|
| 348 |
+
attn = concatenated["attention_mask"][i:i + block_size]
|
| 349 |
+
blocks_input.append(chunk)
|
| 350 |
+
blocks_attn.append(attn)
|
| 351 |
+
blocks_labels.append(chunk.copy())
|
| 352 |
+
|
| 353 |
+
# remainder
|
| 354 |
+
remainder = total_length - full_len
|
| 355 |
+
if remainder > 0 and pack_mode == "pad":
|
| 356 |
+
chunk = concatenated["input_ids"][full_len:full_len + remainder]
|
| 357 |
+
attn = concatenated["attention_mask"][full_len:full_len + remainder]
|
| 358 |
+
|
| 359 |
+
pad_len = block_size - remainder
|
| 360 |
+
chunk_padded = chunk + [pad_id] * pad_len
|
| 361 |
+
attn_padded = attn + [0] * pad_len
|
| 362 |
+
|
| 363 |
+
labels = chunk_padded.copy()
|
| 364 |
+
labels[-pad_len:] = [-100] * pad_len # loss mask
|
| 365 |
+
|
| 366 |
+
blocks_input.append(chunk_padded)
|
| 367 |
+
blocks_attn.append(attn_padded)
|
| 368 |
+
blocks_labels.append(labels)
|
| 369 |
+
|
| 370 |
+
return {
|
| 371 |
+
"input_ids": blocks_input,
|
| 372 |
+
"attention_mask": blocks_attn,
|
| 373 |
+
"labels": blocks_labels,
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
tokenized_train = tokenized_train.map(
|
| 377 |
+
group_texts,
|
| 378 |
+
batched=True,
|
| 379 |
+
num_proc=num_proc,
|
| 380 |
+
desc=f"Packing train blocks (mode={pack_mode})",
|
| 381 |
+
)
|
| 382 |
+
if tokenized_eval is not None:
|
| 383 |
+
tokenized_eval = tokenized_eval.map(
|
| 384 |
+
group_texts,
|
| 385 |
+
batched=True,
|
| 386 |
+
num_proc=num_proc,
|
| 387 |
+
desc=f"Packing eval blocks (mode={pack_mode})",
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
if len(tokenized_train) == 0:
|
| 391 |
+
raise ValueError(
|
| 392 |
+
"Train dataset is empty after packing. "
|
| 393 |
+
"Either increase data, reduce block_size, or set data.pack_mode='pad'."
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
if shuffle:
|
| 397 |
+
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 398 |
+
|
| 399 |
+
return tokenized_train, tokenized_eval
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# --------------------------
|
| 403 |
+
# Model Loading + PEFT
|
| 404 |
+
# --------------------------
|
| 405 |
+
|
| 406 |
+
def _select_model_loader(base_dir: Path):
|
| 407 |
+
cfg_path = base_dir / "config.json"
|
| 408 |
+
if not cfg_path.exists():
|
| 409 |
+
return {"kind": "causal", "arch": None}
|
| 410 |
+
with cfg_path.open("r", encoding="utf-8") as f:
|
| 411 |
+
cfg = json.load(f)
|
| 412 |
+
arch = cfg.get("architectures") or []
|
| 413 |
+
arch_name = arch[0] if arch else None
|
| 414 |
+
if any("ForConditionalGeneration" in a for a in arch):
|
| 415 |
+
return {"kind": "conditional", "arch": arch_name}
|
| 416 |
+
return {"kind": "causal", "arch": arch_name}
|
| 417 |
+
|
| 418 |
+
def _resolve_model_class(arch_name: str):
|
| 419 |
+
import transformers
|
| 420 |
+
cls = getattr(transformers, arch_name, None)
|
| 421 |
+
if cls is None:
|
| 422 |
+
raise ValueError(f"Model class '{arch_name}' is not available in installed transformers.")
|
| 423 |
+
return cls
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 427 |
+
model_cfg = cfg["model"]
|
| 428 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 429 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 430 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 431 |
+
|
| 432 |
+
tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 433 |
+
if tokenizer.pad_token is None:
|
| 434 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 435 |
+
|
| 436 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 437 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 438 |
+
|
| 439 |
+
quant_cfg = None
|
| 440 |
+
if use_4bit:
|
| 441 |
+
if BitsAndBytesConfig is None:
|
| 442 |
+
raise ImportError("BitsAndBytesConfig is not available in this transformers version.")
|
| 443 |
+
quant_cfg = BitsAndBytesConfig(
|
| 444 |
+
load_in_4bit=True,
|
| 445 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 446 |
+
bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
|
| 447 |
+
bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")),
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 451 |
+
model_meta = _select_model_loader(base_dir)
|
| 452 |
+
|
| 453 |
+
try:
|
| 454 |
+
if model_meta["kind"] == "conditional":
|
| 455 |
+
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 456 |
+
if model_cls is None:
|
| 457 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 458 |
+
model = model_cls.from_pretrained(
|
| 459 |
+
str(base_dir),
|
| 460 |
+
device_map=device_map,
|
| 461 |
+
trust_remote_code=trust_remote_code,
|
| 462 |
+
low_cpu_mem_usage=True,
|
| 463 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 464 |
+
quantization_config=quant_cfg,
|
| 465 |
+
attn_implementation=attn_impl,
|
| 466 |
+
)
|
| 467 |
+
else:
|
| 468 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 469 |
+
str(base_dir),
|
| 470 |
+
device_map=device_map,
|
| 471 |
+
trust_remote_code=trust_remote_code,
|
| 472 |
+
low_cpu_mem_usage=True,
|
| 473 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 474 |
+
quantization_config=quant_cfg,
|
| 475 |
+
attn_implementation=attn_impl,
|
| 476 |
+
)
|
| 477 |
+
except Exception as e:
|
| 478 |
+
if attn_impl is not None:
|
| 479 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 480 |
+
print("[warn] Falling back to default attention implementation.")
|
| 481 |
+
if model_meta["kind"] == "conditional":
|
| 482 |
+
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 483 |
+
if model_cls is None:
|
| 484 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 485 |
+
model = model_cls.from_pretrained(
|
| 486 |
+
str(base_dir),
|
| 487 |
+
device_map=device_map,
|
| 488 |
+
trust_remote_code=trust_remote_code,
|
| 489 |
+
low_cpu_mem_usage=True,
|
| 490 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 491 |
+
quantization_config=quant_cfg,
|
| 492 |
+
)
|
| 493 |
+
else:
|
| 494 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 495 |
+
str(base_dir),
|
| 496 |
+
device_map=device_map,
|
| 497 |
+
trust_remote_code=trust_remote_code,
|
| 498 |
+
low_cpu_mem_usage=True,
|
| 499 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 500 |
+
quantization_config=quant_cfg,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
return model, tokenizer
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 507 |
+
peft_cfg = cfg["peft"]
|
| 508 |
+
model_cfg = cfg["model"]
|
| 509 |
+
tr_cfg = cfg["train"]
|
| 510 |
+
|
| 511 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 512 |
+
return model, None
|
| 513 |
+
|
| 514 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 515 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 516 |
+
|
| 517 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 518 |
+
model.gradient_checkpointing_enable()
|
| 519 |
+
if hasattr(model, "config"):
|
| 520 |
+
model.config.use_cache = False
|
| 521 |
+
|
| 522 |
+
if use_4bit:
|
| 523 |
+
model = prepare_model_for_kbit_training(
|
| 524 |
+
model,
|
| 525 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 529 |
+
if target_modules == "auto":
|
| 530 |
+
target_modules = _infer_target_modules(model)
|
| 531 |
+
|
| 532 |
+
lora_config = LoraConfig(
|
| 533 |
+
r=int(peft_cfg.get("r", 16)),
|
| 534 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 535 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 536 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 537 |
+
task_type="CAUSAL_LM",
|
| 538 |
+
target_modules=target_modules,
|
| 539 |
+
)
|
| 540 |
+
model = get_peft_model(model, lora_config)
|
| 541 |
+
return model, lora_config
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# --------------------------
|
| 545 |
+
# Merge Logic
|
| 546 |
+
# --------------------------
|
| 547 |
+
|
| 548 |
+
def merge_adapter(cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path):
|
| 549 |
+
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 550 |
+
|
| 551 |
+
model_cfg = cfg["model"]
|
| 552 |
+
merge_cfg = cfg.get("merge", {})
|
| 553 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 554 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 555 |
+
|
| 556 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 557 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 558 |
+
|
| 559 |
+
model_meta = _select_model_loader(base_dir)
|
| 560 |
+
if model_meta["kind"] == "conditional":
|
| 561 |
+
base_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 562 |
+
if base_cls is None:
|
| 563 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 564 |
+
base = base_cls.from_pretrained(
|
| 565 |
+
str(base_dir),
|
| 566 |
+
torch_dtype=merged_dtype,
|
| 567 |
+
device_map="cpu",
|
| 568 |
+
low_cpu_mem_usage=True,
|
| 569 |
+
trust_remote_code=trust_remote_code,
|
| 570 |
+
)
|
| 571 |
+
else:
|
| 572 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 573 |
+
str(base_dir),
|
| 574 |
+
torch_dtype=merged_dtype,
|
| 575 |
+
device_map="cpu",
|
| 576 |
+
low_cpu_mem_usage=True,
|
| 577 |
+
trust_remote_code=trust_remote_code,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 581 |
+
merged = merged.merge_and_unload()
|
| 582 |
+
|
| 583 |
+
_ensure_dir(final_dir)
|
| 584 |
+
merged.save_pretrained(str(final_dir), safe_serialization=True, max_shard_size=max_shard_size)
|
| 585 |
+
|
| 586 |
+
tok = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 587 |
+
if tok.pad_token is None:
|
| 588 |
+
tok.pad_token = tok.eos_token
|
| 589 |
+
tok.save_pretrained(str(final_dir))
|
| 590 |
+
|
| 591 |
+
print("--- Merge complete ---")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# --------------------------
|
| 595 |
+
# Main
|
| 596 |
+
# --------------------------
|
| 597 |
+
|
| 598 |
+
def main():
|
| 599 |
+
ap = argparse.ArgumentParser()
|
| 600 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 601 |
+
ap.add_argument("--merge-only", action="store_true", help="Skip training, just merge adapter")
|
| 602 |
+
args = ap.parse_args()
|
| 603 |
+
|
| 604 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 605 |
+
cfg = yaml.safe_load(f)
|
| 606 |
+
|
| 607 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 608 |
+
_ensure_dir(run_dir / "logs")
|
| 609 |
+
|
| 610 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 611 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 612 |
+
|
| 613 |
+
model_cfg = cfg["model"]
|
| 614 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 615 |
+
repo_path = Path(repo_id)
|
| 616 |
+
|
| 617 |
+
# ✅ Local model path -> load directly; no download
|
| 618 |
+
if repo_path.exists() and repo_path.is_dir():
|
| 619 |
+
base_dir = repo_path
|
| 620 |
+
if not _looks_like_model_dir(base_dir):
|
| 621 |
+
raise ValueError(f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}")
|
| 622 |
+
else:
|
| 623 |
+
# HF repo_id -> download into run_dir/base_local_dir
|
| 624 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 625 |
+
if not _looks_like_model_dir(base_dir):
|
| 626 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 627 |
+
snapshot_download(
|
| 628 |
+
repo_id=repo_id,
|
| 629 |
+
revision=model_cfg.get("revision", None),
|
| 630 |
+
local_dir=str(base_dir),
|
| 631 |
+
local_dir_use_symlinks=False,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 635 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 636 |
+
|
| 637 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 638 |
+
if merge_cfg.get("output_dir"):
|
| 639 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 640 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 641 |
+
else:
|
| 642 |
+
final_dir = run_dir / "final_model"
|
| 643 |
+
|
| 644 |
+
# Merge-only
|
| 645 |
+
if args.merge_only:
|
| 646 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 647 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 648 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 649 |
+
return
|
| 650 |
+
|
| 651 |
+
# Initialize Wandb
|
| 652 |
+
wandb_run = setup_wandb(cfg, run_dir)
|
| 653 |
+
|
| 654 |
+
# Training
|
| 655 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 656 |
+
|
| 657 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 658 |
+
model, _ = apply_peft(cfg, model)
|
| 659 |
+
|
| 660 |
+
train_ds, eval_ds = build_datasets(cfg, tokenizer)
|
| 661 |
+
|
| 662 |
+
tr_cfg = cfg["train"]
|
| 663 |
+
|
| 664 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 665 |
+
use_fp16 = (dtype == torch.float16)
|
| 666 |
+
use_bf16 = (dtype == torch.bfloat16)
|
| 667 |
+
|
| 668 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 669 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 670 |
+
|
| 671 |
+
# --- Dynamic evaluation strategy parameter handling ---
|
| 672 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 673 |
+
eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 674 |
+
|
| 675 |
+
# Setup reporting based on wandb availability
|
| 676 |
+
report_to = []
|
| 677 |
+
if wandb_run is not None:
|
| 678 |
+
report_to.append("wandb")
|
| 679 |
+
|
| 680 |
+
desired_ta_kwargs = dict(
|
| 681 |
+
output_dir=str(ckpt_dir),
|
| 682 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 683 |
+
num_train_epochs=num_train_epochs,
|
| 684 |
+
|
| 685 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 686 |
+
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1))),
|
| 687 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 688 |
+
|
| 689 |
+
learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
|
| 690 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 691 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 692 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 693 |
+
|
| 694 |
+
optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch")),
|
| 695 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 696 |
+
|
| 697 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 698 |
+
|
| 699 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 700 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 701 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 702 |
+
|
| 703 |
+
eval_steps=int(tr_cfg.get("eval_steps", 200)),
|
| 704 |
+
|
| 705 |
+
load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False,
|
| 706 |
+
metric_for_best_model="eval_loss",
|
| 707 |
+
greater_is_better=False,
|
| 708 |
+
|
| 709 |
+
fp16=use_fp16,
|
| 710 |
+
bf16=use_bf16,
|
| 711 |
+
|
| 712 |
+
report_to=report_to,
|
| 713 |
+
remove_unused_columns=False,
|
| 714 |
+
save_safetensors=True,
|
| 715 |
+
overwrite_output_dir=False,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# Set the correct argument name for this transformers version
|
| 719 |
+
desired_ta_kwargs[eval_key] = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
|
| 720 |
+
ta_kwargs = {k: v for k, v in desired_ta_kwargs.items() if k in ta_params}
|
| 721 |
+
|
| 722 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 723 |
+
|
| 724 |
+
trainer_params = inspect.signature(Trainer.__init__).parameters
|
| 725 |
+
desired_trainer_kwargs = dict(
|
| 726 |
+
model=model,
|
| 727 |
+
args=training_args,
|
| 728 |
+
train_dataset=train_ds,
|
| 729 |
+
eval_dataset=eval_ds,
|
| 730 |
+
tokenizer=tokenizer,
|
| 731 |
+
processing_class=tokenizer,
|
| 732 |
+
data_collator=default_data_collator,
|
| 733 |
+
callbacks=[JsonlLoggerCallback(run_dir)],
|
| 734 |
+
)
|
| 735 |
+
trainer_kwargs = {k: v for k, v in desired_trainer_kwargs.items() if k in trainer_params}
|
| 736 |
+
trainer = Trainer(**trainer_kwargs)
|
| 737 |
+
|
| 738 |
+
# Resume
|
| 739 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 740 |
+
if resume_from == "auto":
|
| 741 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 742 |
+
resume_from = last if last else None
|
| 743 |
+
if resume_from:
|
| 744 |
+
print(f"Resuming from {resume_from}")
|
| 745 |
+
|
| 746 |
+
print("Starting training...")
|
| 747 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 748 |
+
|
| 749 |
+
trainer.save_model(str(best_adapter_dir))
|
| 750 |
+
print(f"Saved best adapter -> {best_adapter_dir}")
|
| 751 |
+
|
| 752 |
+
if eval_ds is not None:
|
| 753 |
+
metrics = trainer.evaluate()
|
| 754 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 755 |
+
metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 756 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 757 |
+
json.dump(metrics, f, indent=2)
|
| 758 |
+
print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
|
| 759 |
+
|
| 760 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 761 |
+
del trainer, model
|
| 762 |
+
torch.cuda.empty_cache()
|
| 763 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 764 |
+
else:
|
| 765 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 766 |
+
|
| 767 |
+
# Finish Wandb run
|
| 768 |
+
finish_wandb()
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
if __name__ == "__main__":
|
| 772 |
+
main()
|
trainer-kit/CPT/README.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Trainer‑Kit : Config‑Driven CPT (LoRA / QLoRA) with Packing, Logging, Resume, and Merge
|
| 2 |
+
|
| 3 |
+
Trainer‑Kit is a small, config‑driven training runner for **continued pretraining (CPT)** on causal LMs.
|
| 4 |
+
It supports **LoRA** and **QLoRA**, data **packing** (strict or padding‑masked), **checkpointing + resume**, **JSONL logging**, periodic **eval with perplexity**, and an optional **merge** step to export a final merged model.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## What we built
|
| 9 |
+
|
| 10 |
+
### ✅ Core goals implemented
|
| 11 |
+
|
| 12 |
+
* **CPT training loop** controlled entirely via a **YAML config**
|
| 13 |
+
* **Local model support** (load from filesystem) and optional **HF download** (if `repo_id` is a hub id)
|
| 14 |
+
* **JSONL datasets** for train (+ optional eval split)
|
| 15 |
+
* **CPT‑style token stream packing** into fixed‑length blocks
|
| 16 |
+
* **Two packing modes**
|
| 17 |
+
|
| 18 |
+
* `drop`: strict CPT, drop remainder tokens (preferred for real CPT)
|
| 19 |
+
* `pad`: pad the remainder to `block_size` and **mask loss** on padding (useful for small datasets / debugging)
|
| 20 |
+
* **Checkpointing + resume**
|
| 21 |
+
|
| 22 |
+
* `resume_from_checkpoint: "auto"` resumes from the latest checkpoint under `run_dir/checkpoints`
|
| 23 |
+
* **JSONL logs** written locally
|
| 24 |
+
|
| 25 |
+
* training logs: `run_dir/logs/train.jsonl`
|
| 26 |
+
* eval logs: `run_dir/logs/eval.jsonl`
|
| 27 |
+
* **Evaluation**
|
| 28 |
+
|
| 29 |
+
* logs `eval_loss` and computed `perplexity = exp(eval_loss)` (with safe overflow guard)
|
| 30 |
+
* **Adapter output**
|
| 31 |
+
|
| 32 |
+
* saves the final/best adapter to `run_dir/best_adapter`
|
| 33 |
+
* **Merge workflow**
|
| 34 |
+
|
| 35 |
+
* `--merge-only` merges an existing adapter later
|
| 36 |
+
* merge is done **on CPU** to avoid GPU OOM
|
| 37 |
+
* merged model is stored under the configured merge output directory (relative to `run_dir` if a relative path)
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
## Repository layout (outputs)
|
| 42 |
+
|
| 43 |
+
A run produces the following structure under `run.run_dir`:
|
| 44 |
+
|
| 45 |
+
```
|
| 46 |
+
runs/<run_name>/
|
| 47 |
+
├─ checkpoints/ # trainer checkpoints (for resume)
|
| 48 |
+
├─ best_adapter/ # saved LoRA adapter
|
| 49 |
+
├─ logs/
|
| 50 |
+
│ ├─ train.jsonl # step-wise training logs
|
| 51 |
+
│ └─ eval.jsonl # eval logs (eval_loss + perplexity)
|
| 52 |
+
├─ eval_final.json # final eval metrics summary (if eval is enabled)
|
| 53 |
+
└─ config_resolved.yaml # exact config used for the run
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
If merge is used, the merged model is written to:
|
| 57 |
+
|
| 58 |
+
* `run_dir/<merge.output_dir>` if `merge.output_dir` is relative (e.g. `./merged_model`)
|
| 59 |
+
* or the absolute path if it is absolute.
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## Supported training modes
|
| 64 |
+
|
| 65 |
+
### 1) LoRA vs QLoRA (same script)
|
| 66 |
+
|
| 67 |
+
* **QLoRA** happens when `model.use_4bit: true`
|
| 68 |
+
|
| 69 |
+
* base weights are loaded in 4‑bit using bitsandbytes
|
| 70 |
+
* training updates only LoRA parameters
|
| 71 |
+
* **LoRA** happens when `model.use_4bit: false`
|
| 72 |
+
|
| 73 |
+
* base weights are loaded in fp16/bf16 (as configured)
|
| 74 |
+
* training updates only LoRA parameters
|
| 75 |
+
|
| 76 |
+
No “full finetune” mode is enabled by default in this runner.
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## Data pipeline (CPT behavior)
|
| 81 |
+
|
| 82 |
+
### Input format
|
| 83 |
+
|
| 84 |
+
* JSONL file where each line contains a text field (default `"text"`).
|
| 85 |
+
* Example:
|
| 86 |
+
|
| 87 |
+
* `{"text": "some training text..."}`
|
| 88 |
+
|
| 89 |
+
### Packing (token stream → fixed blocks)
|
| 90 |
+
|
| 91 |
+
* Each sample is tokenized without truncation.
|
| 92 |
+
* An **EOS token is appended** per document to preserve boundaries.
|
| 93 |
+
* Token lists are concatenated and converted into **fixed‑length blocks** of `data.block_size`.
|
| 94 |
+
|
| 95 |
+
Two modes:
|
| 96 |
+
|
| 97 |
+
* **`drop` (strict CPT):** remainder tokens that don’t fill a full block are discarded.
|
| 98 |
+
* **`pad` (debug/small data):** remainder is padded to block_size:
|
| 99 |
+
|
| 100 |
+
* `attention_mask = 0` for padded positions
|
| 101 |
+
* `labels = -100` for padded positions (loss masking)
|
| 102 |
+
|
| 103 |
+
This is what allowed training to proceed even with tiny dummy datasets at `block_size=1024`.
|
| 104 |
+
|
| 105 |
+
---
|
| 106 |
+
|
| 107 |
+
## Logging
|
| 108 |
+
|
| 109 |
+
Trainer‑Kit writes **machine‑readable logs** in JSONL.
|
| 110 |
+
|
| 111 |
+
### Training logs (`logs/train.jsonl`)
|
| 112 |
+
|
| 113 |
+
Includes entries with:
|
| 114 |
+
|
| 115 |
+
* `step`
|
| 116 |
+
* `loss`
|
| 117 |
+
* `grad_norm`
|
| 118 |
+
* `learning_rate`
|
| 119 |
+
* `progress_pct` (step progress when `max_steps` is active)
|
| 120 |
+
* ETA estimation
|
| 121 |
+
|
| 122 |
+
### Eval logs (`logs/eval.jsonl`)
|
| 123 |
+
|
| 124 |
+
Includes:
|
| 125 |
+
|
| 126 |
+
* `eval_loss`
|
| 127 |
+
* `perplexity`
|
| 128 |
+
|
| 129 |
+
Notes:
|
| 130 |
+
|
| 131 |
+
* When using `max_steps`, the Trainer’s internal `epoch` counter can grow unexpectedly on tiny datasets (because steps/epoch becomes ~1).
|
| 132 |
+
**Use `progress_pct` as the reliable indicator** for step‑based runs.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## Checkpointing and resume
|
| 137 |
+
|
| 138 |
+
The trainer saves checkpoints under:
|
| 139 |
+
|
| 140 |
+
* `run_dir/checkpoints/`
|
| 141 |
+
|
| 142 |
+
Resume options:
|
| 143 |
+
|
| 144 |
+
* `resume_from_checkpoint: "auto"` → picks the latest checkpoint automatically
|
| 145 |
+
* `resume_from_checkpoint: "/path/to/checkpoint"` → resumes from a specific checkpoint
|
| 146 |
+
* `resume_from_checkpoint: null` → fresh run
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Merging adapters into a final model
|
| 151 |
+
|
| 152 |
+
Trainer‑Kit supports exporting a merged model:
|
| 153 |
+
|
| 154 |
+
### Merge after training
|
| 155 |
+
|
| 156 |
+
* Enable merge in config (`merge.enabled: true`)
|
| 157 |
+
* The script will:
|
| 158 |
+
|
| 159 |
+
1. save the adapter
|
| 160 |
+
2. free GPU memory
|
| 161 |
+
3. reload base model on **CPU**
|
| 162 |
+
4. load adapter
|
| 163 |
+
5. `merge_and_unload()`
|
| 164 |
+
6. save final merged model
|
| 165 |
+
|
| 166 |
+
### Merge later
|
| 167 |
+
|
| 168 |
+
Run:
|
| 169 |
+
|
| 170 |
+
```
|
| 171 |
+
python run_cpt.py --config config.yaml --merge-only
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
This skips training and merges `run_dir/best_adapter` into the base model.
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## How to run
|
| 179 |
+
|
| 180 |
+
### Train
|
| 181 |
+
|
| 182 |
+
```
|
| 183 |
+
python run_cpt.py --config config.yaml
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### Merge only
|
| 187 |
+
|
| 188 |
+
```
|
| 189 |
+
python run_cpt.py --config config.yaml --merge-only
|
trainer-kit/CPT/commands.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Commands
|
| 2 |
+
|
| 3 |
+
Train (no merge):
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
python run_cpt.py --config config.yaml
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
Merge later:
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
python run_cpt.py --config config.yaml --merge-only
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
---
|
trainer-kit/CPT/config.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
run_dir: "./runs/cpt_run_v1"
|
| 3 |
+
seed: 42
|
| 4 |
+
|
| 5 |
+
model:
|
| 6 |
+
# Local model path (no download)
|
| 7 |
+
repo_id: "/workspace/Models/Devstral-Small-2-24B-Instruct-2512"
|
| 8 |
+
revision: null
|
| 9 |
+
|
| 10 |
+
# Used only when repo_id is a HF repo (not a local path)
|
| 11 |
+
base_local_dir: "base_model"
|
| 12 |
+
|
| 13 |
+
trust_remote_code: true
|
| 14 |
+
tokenizer_use_fast: true
|
| 15 |
+
device_map: "auto"
|
| 16 |
+
|
| 17 |
+
torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
|
| 18 |
+
|
| 19 |
+
# QLoRA
|
| 20 |
+
use_4bit: false
|
| 21 |
+
bnb_4bit_quant_type: "nf4"
|
| 22 |
+
bnb_4bit_use_double_quant: false
|
| 23 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 24 |
+
|
| 25 |
+
# optional: "flash_attention_2" | "sdpa" | null
|
| 26 |
+
attn_implementation: null
|
| 27 |
+
|
| 28 |
+
data:
|
| 29 |
+
train_jsonl: "/workspace/all_data_with_descriptions.jsonl"
|
| 30 |
+
eval_jsonl: null
|
| 31 |
+
eval_split_ratio: 0.1
|
| 32 |
+
text_field: "text"
|
| 33 |
+
block_size: 4096
|
| 34 |
+
shuffle: true
|
| 35 |
+
num_proc: 4
|
| 36 |
+
|
| 37 |
+
# ✅ NEW: packing behavior
|
| 38 |
+
# "drop" = strict CPT (drop remainder)
|
| 39 |
+
# "pad" = pad remainder to block_size + loss mask (-100) + attention_mask=0
|
| 40 |
+
pack_mode: "pad"
|
| 41 |
+
|
| 42 |
+
peft:
|
| 43 |
+
enabled: true
|
| 44 |
+
r: 64
|
| 45 |
+
lora_alpha: 128
|
| 46 |
+
lora_dropout: 0.05
|
| 47 |
+
bias: "none"
|
| 48 |
+
target_modules: "auto"
|
| 49 |
+
|
| 50 |
+
train:
|
| 51 |
+
#max_steps: 1000
|
| 52 |
+
num_train_epochs: 2
|
| 53 |
+
|
| 54 |
+
per_device_train_batch_size: 1
|
| 55 |
+
per_device_eval_batch_size: 1
|
| 56 |
+
gradient_accumulation_steps: 16
|
| 57 |
+
|
| 58 |
+
learning_rate: 2e-5
|
| 59 |
+
weight_decay: 0.0
|
| 60 |
+
warmup_ratio: 0.1
|
| 61 |
+
lr_scheduler_type: "cosine"
|
| 62 |
+
|
| 63 |
+
optim: "paged_adamw_8bit"
|
| 64 |
+
max_grad_norm: 1.0
|
| 65 |
+
gradient_checkpointing: true
|
| 66 |
+
|
| 67 |
+
logging_steps: 1
|
| 68 |
+
save_strategy: "steps"
|
| 69 |
+
save_steps: 100
|
| 70 |
+
save_total_limit: 4
|
| 71 |
+
|
| 72 |
+
evaluation_strategy: "steps"
|
| 73 |
+
eval_steps: 50
|
| 74 |
+
load_best_model_at_end: true
|
| 75 |
+
|
| 76 |
+
resume_from_checkpoint: "auto"
|
| 77 |
+
|
| 78 |
+
merge:
|
| 79 |
+
enabled: true
|
| 80 |
+
merged_dtype: "float16"
|
| 81 |
+
max_shard_size: "2GB"
|
| 82 |
+
output_dir: "./merged_24b_cpt_lora"
|
trainer-kit/CPT/detailed_parameter_documentation.md
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CPT Configuration Parameters: Detailed Guide
|
| 2 |
+
|
| 3 |
+
This document provides a comprehensive explanation of all configuration parameters in `config.yaml` and how they're implemented in `run_cpt.py`.
|
| 4 |
+
|
| 5 |
+
## Table of Contents
|
| 6 |
+
- [Run Parameters](#run-parameters)
|
| 7 |
+
- [Model Parameters](#model-parameters)
|
| 8 |
+
- [Data Parameters](#data-parameters)
|
| 9 |
+
- [PEFT Parameters](#peft-parameters)
|
| 10 |
+
- [Training Parameters](#training-parameters)
|
| 11 |
+
- [Merge Parameters](#merge-parameters)
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Run Parameters
|
| 16 |
+
|
| 17 |
+
### `run.run_dir`
|
| 18 |
+
- **Type**: String (path)
|
| 19 |
+
- **Required**: Yes
|
| 20 |
+
- **Default**: No default
|
| 21 |
+
- **Description**: Directory where training outputs will be saved
|
| 22 |
+
- **Used in**: Line ~480 in `run_cpt.py`
|
| 23 |
+
- **Implementation**:
|
| 24 |
+
```python
|
| 25 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 26 |
+
```
|
| 27 |
+
- **Example Values**:
|
| 28 |
+
- `./runs/cpt_run_v1`
|
| 29 |
+
- `/workspace/outputs/my_experiment`
|
| 30 |
+
- `./checkpoints/cpt_experiment`
|
| 31 |
+
|
| 32 |
+
### `run.seed`
|
| 33 |
+
- **Type**: Integer
|
| 34 |
+
- **Required**: No
|
| 35 |
+
- **Default**: None
|
| 36 |
+
- **Description**: Random seed for reproducibility
|
| 37 |
+
- **Used in**: Lines ~460, ~240 in `run_cpt.py`
|
| 38 |
+
- **Implementation**:
|
| 39 |
+
```python
|
| 40 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 41 |
+
# Used in data shuffling and train/test split
|
| 42 |
+
```
|
| 43 |
+
- **Example Values**: `42`, `123`, `2023`
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## Model Parameters
|
| 48 |
+
|
| 49 |
+
### `model.repo_id`
|
| 50 |
+
- **Type**: String (path or HuggingFace repo)
|
| 51 |
+
- **Required**: Yes
|
| 52 |
+
- **Default**: No default
|
| 53 |
+
- **Description**: Model identifier - can be local path or HuggingFace repository
|
| 54 |
+
- **Used in**: Lines ~480-500 in `run_cpt.py`
|
| 55 |
+
- **Implementation**:
|
| 56 |
+
```python
|
| 57 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 58 |
+
repo_path = Path(repo_id)
|
| 59 |
+
if repo_path.exists() and repo_path.is_dir():
|
| 60 |
+
base_dir = repo_path # Local path
|
| 61 |
+
else:
|
| 62 |
+
# Download from HuggingFace
|
| 63 |
+
snapshot_download(repo_id=repo_id, ...)
|
| 64 |
+
```
|
| 65 |
+
- **Example Values**:
|
| 66 |
+
- Local: `/workspace/Models/Devstral-Small-2-24B-Instruct-2512`
|
| 67 |
+
- HF Repo: `meta-llama/Llama-2-7b-hf`
|
| 68 |
+
|
| 69 |
+
### `model.revision`
|
| 70 |
+
- **Type**: String or null
|
| 71 |
+
- **Required**: No
|
| 72 |
+
- **Default**: null
|
| 73 |
+
- **Description**: Specific model revision/branch/tag from HuggingFace
|
| 74 |
+
- **Used in**: Line ~495 in `run_cpt.py`
|
| 75 |
+
- **Implementation**:
|
| 76 |
+
```python
|
| 77 |
+
snapshot_download(
|
| 78 |
+
repo_id=repo_id,
|
| 79 |
+
revision=model_cfg.get("revision", None),
|
| 80 |
+
...
|
| 81 |
+
)
|
| 82 |
+
```
|
| 83 |
+
- **Example Values**:
|
| 84 |
+
- `"main"` - Main branch
|
| 85 |
+
- `"v1.0"` - Specific tag
|
| 86 |
+
- `"abc123def"` - Specific commit hash
|
| 87 |
+
- `null` - Latest version
|
| 88 |
+
|
| 89 |
+
### `model.base_local_dir`
|
| 90 |
+
- **Type**: String (path)
|
| 91 |
+
- **Required**: No
|
| 92 |
+
- **Default**: `"base_model"`
|
| 93 |
+
- **Description**: Directory name for downloaded model when using HF repo
|
| 94 |
+
- **Used in**: Line ~495 in `run_cpt.py`
|
| 95 |
+
- **Implementation**:
|
| 96 |
+
```python
|
| 97 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 98 |
+
```
|
| 99 |
+
- **Example Values**: `"base_model"`, `"downloaded_model"`, `"model_files"`
|
| 100 |
+
|
| 101 |
+
### `model.trust_remote_code`
|
| 102 |
+
- **Type**: Boolean
|
| 103 |
+
- **Required**: No
|
| 104 |
+
- **Default**: `true`
|
| 105 |
+
- **Description**: Allow loading models with custom code
|
| 106 |
+
- **Used in**: Lines ~320, ~340, ~450 in `run_cpt.py`
|
| 107 |
+
- **Implementation**:
|
| 108 |
+
```python
|
| 109 |
+
tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 110 |
+
model = AutoModelForCausalLM.from_pretrained(..., trust_remote_code=trust_remote_code, ...)
|
| 111 |
+
```
|
| 112 |
+
- **Example Values**: `true`, `false`
|
| 113 |
+
|
| 114 |
+
### `model.tokenizer_use_fast`
|
| 115 |
+
- **Type**: Boolean
|
| 116 |
+
- **Required**: No
|
| 117 |
+
- **Default**: `true`
|
| 118 |
+
- **Description**: Use fast tokenizer implementation
|
| 119 |
+
- **Used in**: Lines ~320, ~450 in `run_cpt.py`
|
| 120 |
+
- **Implementation**:
|
| 121 |
+
```python
|
| 122 |
+
tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 123 |
+
```
|
| 124 |
+
- **Example Values**: `true`, `false`
|
| 125 |
+
|
| 126 |
+
### `model.device_map`
|
| 127 |
+
- **Type**: String
|
| 128 |
+
- **Required**: No
|
| 129 |
+
- **Default**: `"auto"`
|
| 130 |
+
- **Description**: How to distribute model across devices
|
| 131 |
+
- **Used in**: Lines ~350, ~370 in `run_cpt.py`
|
| 132 |
+
- **Implementation**:
|
| 133 |
+
```python
|
| 134 |
+
model = AutoModelForCausalLM.from_pretrained(..., device_map=device_map, ...)
|
| 135 |
+
```
|
| 136 |
+
- **Example Values**:
|
| 137 |
+
- `"auto"` - Automatic distribution
|
| 138 |
+
- `"cpu"` - CPU only
|
| 139 |
+
- `"cuda:0"` - Single GPU
|
| 140 |
+
- `{"": 0}` - Manual mapping
|
| 141 |
+
|
| 142 |
+
### `model.torch_dtype`
|
| 143 |
+
- **Type**: String
|
| 144 |
+
- **Required**: No
|
| 145 |
+
- **Default**: `"bfloat16"`
|
| 146 |
+
- **Description**: Data type for model tensors
|
| 147 |
+
- **Used in**: Lines ~45, ~350 in `run_cpt.py`
|
| 148 |
+
- **Implementation**:
|
| 149 |
+
```python
|
| 150 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 151 |
+
if s in ("float16", "fp16"): return torch.float16
|
| 152 |
+
if s in ("bfloat16", "bf16"): return torch.bfloat16
|
| 153 |
+
if s in ("float32", "fp32"): return torch.float32
|
| 154 |
+
```
|
| 155 |
+
- **Example Values**:
|
| 156 |
+
- `"float16"` - 16-bit floats (faster, less memory, less stable)
|
| 157 |
+
- `"bfloat16"` - Brain float16 (stable, good for training)
|
| 158 |
+
- `"float32"` - 32-bit floats (slowest, most memory)
|
| 159 |
+
|
| 160 |
+
### `model.use_4bit`
|
| 161 |
+
- **Type**: Boolean
|
| 162 |
+
- **Required**: No
|
| 163 |
+
- **Default**: `false`
|
| 164 |
+
- **Description**: Use 4-bit quantization for memory efficiency
|
| 165 |
+
- **Used in**: Lines ~325, ~395 in `run_cpt.py`
|
| 166 |
+
- **Implementation**:
|
| 167 |
+
```python
|
| 168 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 169 |
+
if use_4bit:
|
| 170 |
+
quant_cfg = BitsAndBytesConfig(load_in_4bit=True, ...)
|
| 171 |
+
```
|
| 172 |
+
- **Example Values**: `true`, `false`
|
| 173 |
+
|
| 174 |
+
### `model.bnb_4bit_quant_type`
|
| 175 |
+
- **Type**: String
|
| 176 |
+
- **Required**: No
|
| 177 |
+
- **Default**: `"nf4"`
|
| 178 |
+
- **Description**: 4-bit quantization type
|
| 179 |
+
- **Used in**: Lines ~328 in `run_cpt.py`
|
| 180 |
+
- **Implementation**:
|
| 181 |
+
```python
|
| 182 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4"))
|
| 183 |
+
```
|
| 184 |
+
- **Example Values**:
|
| 185 |
+
- `"nf4"` - NormalFloat4 (recommended)
|
| 186 |
+
- `"fp4"` - FloatingPoint4
|
| 187 |
+
- `"int4"` - Integer4
|
| 188 |
+
|
| 189 |
+
### `model.bnb_4bit_use_double_quant`
|
| 190 |
+
- **Type**: Boolean
|
| 191 |
+
- **Required**: No
|
| 192 |
+
- **Default**: `false`
|
| 193 |
+
- **Description**: Use double quantization for memory efficiency
|
| 194 |
+
- **Used in**: Lines ~329 in `run_cpt.py`
|
| 195 |
+
- **Implementation**:
|
| 196 |
+
```python
|
| 197 |
+
bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True))
|
| 198 |
+
```
|
| 199 |
+
- **Example Values**: `true`, `false`
|
| 200 |
+
|
| 201 |
+
### `model.bnb_4bit_compute_dtype`
|
| 202 |
+
- **Type**: String
|
| 203 |
+
- **Required**: No
|
| 204 |
+
- **Default**: `"bfloat16"`
|
| 205 |
+
- **Description**: Compute dtype for 4-bit quantization
|
| 206 |
+
- **Used in**: Lines ~330 in `run_cpt.py`
|
| 207 |
+
- **Implementation**:
|
| 208 |
+
```python
|
| 209 |
+
bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16"))
|
| 210 |
+
```
|
| 211 |
+
- **Example Values**: `"float16"`, `"bfloat16"`, `"float32"`
|
| 212 |
+
|
| 213 |
+
### `model.attn_implementation`
|
| 214 |
+
- **Type**: String or null
|
| 215 |
+
- **Required**: No
|
| 216 |
+
- **Default**: `null`
|
| 217 |
+
- **Description**: Attention implementation to use
|
| 218 |
+
- **Used in**: Lines ~155, ~350 in `run_cpt.py`
|
| 219 |
+
- **Implementation**:
|
| 220 |
+
```python
|
| 221 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 222 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 223 |
+
# Used in model.from_pretrained(..., attn_implementation=attn_impl, ...)
|
| 224 |
+
```
|
| 225 |
+
- **Example Values**:
|
| 226 |
+
- `"flash_attention_2"` - Flash Attention 2 (fastest)
|
| 227 |
+
- `"sdpa"` - Scaled Dot-Product Attention
|
| 228 |
+
- `null` - Default implementation
|
| 229 |
+
|
| 230 |
+
---
|
| 231 |
+
|
| 232 |
+
## Data Parameters
|
| 233 |
+
|
| 234 |
+
### `data.train_jsonl`
|
| 235 |
+
- **Type**: String (path)
|
| 236 |
+
- **Required**: Yes
|
| 237 |
+
- **Default**: No default
|
| 238 |
+
- **Description**: Path to training data in JSONL format
|
| 239 |
+
- **Used in**: Lines ~170 in `run_cpt.py`
|
| 240 |
+
- **Implementation**:
|
| 241 |
+
```python
|
| 242 |
+
train_path = data_cfg["train_jsonl"]
|
| 243 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 244 |
+
```
|
| 245 |
+
- **Example Values**: `"/workspace/all_data_with_descriptions.jsonl"`
|
| 246 |
+
|
| 247 |
+
### `data.eval_jsonl`
|
| 248 |
+
- **Type**: String (path) or null
|
| 249 |
+
- **Required**: No
|
| 250 |
+
- **Default**: `null`
|
| 251 |
+
- **Description**: Path to evaluation data in JSONL format
|
| 252 |
+
- **Used in**: Lines ~175 in `run_cpt.py`
|
| 253 |
+
- **Implementation**:
|
| 254 |
+
```python
|
| 255 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 256 |
+
if eval_path:
|
| 257 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 258 |
+
```
|
| 259 |
+
- **Example Values**: `null` (no separate eval file), `"/workspace/eval_data.jsonl"`
|
| 260 |
+
|
| 261 |
+
### `data.eval_split_ratio`
|
| 262 |
+
- **Type**: Float
|
| 263 |
+
- **Required**: No
|
| 264 |
+
- **Default**: `0.1`
|
| 265 |
+
- **Description**: Ratio of training data to use for evaluation split
|
| 266 |
+
- **Used in**: Lines ~177 in `run_cpt.py`
|
| 267 |
+
- **Implementation**:
|
| 268 |
+
```python
|
| 269 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 270 |
+
if 0.0 < split_ratio < 1.0:
|
| 271 |
+
split = ds["train"].train_test_split(test_size=split_ratio, seed=seed)
|
| 272 |
+
```
|
| 273 |
+
- **Example Values**: `0.1` (10%), `0.2` (20%), `0.05` (5%)
|
| 274 |
+
|
| 275 |
+
### `data.text_field`
|
| 276 |
+
- **Type**: String
|
| 277 |
+
- **Required**: No
|
| 278 |
+
- **Default**: `"text"`
|
| 279 |
+
- **Description**: Field name in JSONL containing the text data
|
| 280 |
+
- **Used in**: Lines ~185 in `run_cpt.py`
|
| 281 |
+
- **Implementation**:
|
| 282 |
+
```python
|
| 283 |
+
text_field = data_cfg.get("text_field", "text")
|
| 284 |
+
# Used in tokenization
|
| 285 |
+
tokenized = dsd["train"].map(
|
| 286 |
+
tokenize_fn,
|
| 287 |
+
batched=True,
|
| 288 |
+
remove_columns=dsd["train"].column_names,
|
| 289 |
+
desc="Tokenizing train",
|
| 290 |
+
)
|
| 291 |
+
```
|
| 292 |
+
- **Example Values**: `"text"`, `"content"`, `"prompt"`, `"input"`
|
| 293 |
+
|
| 294 |
+
### `data.block_size`
|
| 295 |
+
- **Type**: Integer
|
| 296 |
+
- **Required**: No
|
| 297 |
+
- **Default**: `4096`
|
| 298 |
+
- **Description**: Maximum sequence length for training
|
| 299 |
+
- **Used in**: Lines ~180 in `run_cpt.py`
|
| 300 |
+
- **Implementation**:
|
| 301 |
+
```python
|
| 302 |
+
block_size = int(data_cfg.get("block_size", 2048))
|
| 303 |
+
# Used in grouping texts into blocks
|
| 304 |
+
for i in range(0, full_len, block_size):
|
| 305 |
+
chunk = concatenated["input_ids"][i:i + block_size]
|
| 306 |
+
```
|
| 307 |
+
- **Example Values**: `2048`, `4096`, `8192`
|
| 308 |
+
|
| 309 |
+
### `data.shuffle`
|
| 310 |
+
- **Type**: Boolean
|
| 311 |
+
- **Required**: No
|
| 312 |
+
- **Default**: `true`
|
| 313 |
+
- **Description**: Whether to shuffle training data
|
| 314 |
+
- **Used in**: Lines ~235 in `run_cpt.py`
|
| 315 |
+
- **Implementation**:
|
| 316 |
+
```python
|
| 317 |
+
if shuffle:
|
| 318 |
+
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 319 |
+
```
|
| 320 |
+
- **Example Values**: `true`, `false`
|
| 321 |
+
|
| 322 |
+
### `data.num_proc`
|
| 323 |
+
- **Type**: Integer
|
| 324 |
+
- **Required**: No
|
| 325 |
+
- **Default**: `4`
|
| 326 |
+
- **Description**: Number of processes for data loading
|
| 327 |
+
- **Used in**: Lines ~200, ~210 in `run_cpt.py`
|
| 328 |
+
- **Implementation**:
|
| 329 |
+
```python
|
| 330 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 331 |
+
tokenized_train = dsd["train"].map(
|
| 332 |
+
tokenize_fn,
|
| 333 |
+
batched=True,
|
| 334 |
+
num_proc=num_proc,
|
| 335 |
+
...
|
| 336 |
+
)
|
| 337 |
+
```
|
| 338 |
+
- **Example Values**: `1`, `4`, `8`, `16`
|
| 339 |
+
|
| 340 |
+
### `data.pack_mode`
|
| 341 |
+
- **Type**: String
|
| 342 |
+
- **Required**: No
|
| 343 |
+
- **Default**: `"pad"`
|
| 344 |
+
- **Description**: How to handle remainder tokens in final block
|
| 345 |
+
- **Used in**: Lines ~150-230 in `run_cpt.py`
|
| 346 |
+
- **Implementation**:
|
| 347 |
+
```python
|
| 348 |
+
pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
|
| 349 |
+
if pack_mode == "pad":
|
| 350 |
+
# Pad remainder and mask loss
|
| 351 |
+
labels[-pad_len:] = [-100] * pad_len
|
| 352 |
+
# If "drop": ignore remainder entirely
|
| 353 |
+
```
|
| 354 |
+
- **Example Values**:
|
| 355 |
+
- `"drop"` - Drop incomplete blocks (strict CPT)
|
| 356 |
+
- `"pad"` - Pad incomplete blocks with masked loss
|
| 357 |
+
|
| 358 |
+
---
|
| 359 |
+
|
| 360 |
+
## PEFT Parameters
|
| 361 |
+
|
| 362 |
+
### `peft.enabled`
|
| 363 |
+
- **Type**: Boolean
|
| 364 |
+
- **Required**: No
|
| 365 |
+
- **Default**: `true`
|
| 366 |
+
- **Description**: Whether to use PEFT (Parameter-Efficient Fine-Tuning)
|
| 367 |
+
- **Used in**: Lines ~395 in `run_cpt.py`
|
| 368 |
+
- **Implementation**:
|
| 369 |
+
```python
|
| 370 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 371 |
+
return model, None
|
| 372 |
+
# Otherwise proceed with LoRA configuration
|
| 373 |
+
```
|
| 374 |
+
- **Example Values**: `true`, `false`
|
| 375 |
+
|
| 376 |
+
### `peft.r`
|
| 377 |
+
- **Type**: Integer
|
| 378 |
+
- **Required**: No
|
| 379 |
+
- **Default**: `64`
|
| 380 |
+
- **Description**: LoRA rank - dimension of low-rank matrices
|
| 381 |
+
- **Used in**: Lines ~415 in `run_cpt.py`
|
| 382 |
+
- **Implementation**:
|
| 383 |
+
```python
|
| 384 |
+
lora_config = LoraConfig(
|
| 385 |
+
r=int(peft_cfg.get("r", 16)),
|
| 386 |
+
...
|
| 387 |
+
)
|
| 388 |
+
```
|
| 389 |
+
- **Example Values**: `8`, `16`, `32`, `64`, `128`
|
| 390 |
+
- **Note**: Higher values = more parameters but potentially better performance
|
| 391 |
+
|
| 392 |
+
### `peft.lora_alpha`
|
| 393 |
+
- **Type**: Integer
|
| 394 |
+
- **Required**: No
|
| 395 |
+
- **Default**: `128`
|
| 396 |
+
- **Description**: LoRA alpha scaling parameter
|
| 397 |
+
- **Used in**: Lines ~416 in `run_cpt.py`
|
| 398 |
+
- **Implementation**:
|
| 399 |
+
```python
|
| 400 |
+
lora_config = LoraConfig(
|
| 401 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 402 |
+
...
|
| 403 |
+
)
|
| 404 |
+
```
|
| 405 |
+
- **Example Values**: `16`, `32`, `64`, `128`, `256`
|
| 406 |
+
|
| 407 |
+
### `peft.lora_dropout`
|
| 408 |
+
- **Type**: Float
|
| 409 |
+
- **Required**: No
|
| 410 |
+
- **Default**: `0.05`
|
| 411 |
+
- **Description**: Dropout rate for LoRA layers
|
| 412 |
+
- **Used in**: Lines ~417 in `run_cpt.py`
|
| 413 |
+
- **Implementation**:
|
| 414 |
+
```python
|
| 415 |
+
lora_config = LoraConfig(
|
| 416 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 417 |
+
...
|
| 418 |
+
)
|
| 419 |
+
```
|
| 420 |
+
- **Example Values**: `0.0`, `0.05`, `0.1`, `0.2`
|
| 421 |
+
|
| 422 |
+
### `peft.bias`
|
| 423 |
+
- **Type**: String
|
| 424 |
+
- **Required**: No
|
| 425 |
+
- **Default**: `"none"`
|
| 426 |
+
- **Description**: Bias training strategy
|
| 427 |
+
- **Used in**: Lines ~418 in `run_cpt.py`
|
| 428 |
+
- **Implementation**:
|
| 429 |
+
```python
|
| 430 |
+
lora_config = LoraConfig(
|
| 431 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 432 |
+
...
|
| 433 |
+
)
|
| 434 |
+
```
|
| 435 |
+
- **Example Values**:
|
| 436 |
+
- `"none"` - No bias training
|
| 437 |
+
- `"all"` - Train all biases
|
| 438 |
+
- `"lora_only"` - Only LoRA bias
|
| 439 |
+
|
| 440 |
+
### `peft.target_modules`
|
| 441 |
+
- **Type**: String or List
|
| 442 |
+
- **Required**: No
|
| 443 |
+
- **Default**: `"auto"`
|
| 444 |
+
- **Description**: Which modules to apply LoRA to
|
| 445 |
+
- **Used in**: Lines ~405, ~140-170 in `run_cpt.py`
|
| 446 |
+
- **Implementation**:
|
| 447 |
+
```python
|
| 448 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 449 |
+
if target_modules == "auto":
|
| 450 |
+
target_modules = _infer_target_modules(model)
|
| 451 |
+
```
|
| 452 |
+
- **Example Values**:
|
| 453 |
+
- `"auto"` - Automatic detection
|
| 454 |
+
- `["q_proj", "k_proj", "v_proj", "o_proj"]` - Explicit list
|
| 455 |
+
- `["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]` - MLP only
|
| 456 |
+
|
| 457 |
+
---
|
| 458 |
+
|
| 459 |
+
## Training Parameters
|
| 460 |
+
|
| 461 |
+
### `train.num_train_epochs`
|
| 462 |
+
- **Type**: Float
|
| 463 |
+
- **Required**: No
|
| 464 |
+
- **Default**: `2`
|
| 465 |
+
- **Description**: Number of epochs to train
|
| 466 |
+
- **Used in**: Lines ~470 in `run_cpt.py`
|
| 467 |
+
- **Implementation**:
|
| 468 |
+
```python
|
| 469 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 470 |
+
# Used in TrainingArguments
|
| 471 |
+
```
|
| 472 |
+
- **Example Values**: `1.0`, `2.0`, `3.5`
|
| 473 |
+
|
| 474 |
+
### `train.per_device_train_batch_size`
|
| 475 |
+
- **Type**: Integer
|
| 476 |
+
- **Required**: No
|
| 477 |
+
- **Default**: `1`
|
| 478 |
+
- **Description**: Training batch size per device
|
| 479 |
+
- **Used in**: Lines ~475 in `run_cpt.py`
|
| 480 |
+
- **Implementation**:
|
| 481 |
+
```python
|
| 482 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1))
|
| 483 |
+
```
|
| 484 |
+
- **Example Values**: `1`, `2`, `4`, `8`
|
| 485 |
+
|
| 486 |
+
### `train.per_device_eval_batch_size`
|
| 487 |
+
- **Type**: Integer
|
| 488 |
+
- **Required**: No
|
| 489 |
+
- **Default**: Same as train batch size
|
| 490 |
+
- **Description**: Evaluation batch size per device
|
| 491 |
+
- **Used in**: Lines ~476 in `run_cpt.py`
|
| 492 |
+
- **Implementation**:
|
| 493 |
+
```python
|
| 494 |
+
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1)))
|
| 495 |
+
```
|
| 496 |
+
- **Example Values**: `1`, `2`, `4`, `8`
|
| 497 |
+
|
| 498 |
+
### `train.gradient_accumulation_steps`
|
| 499 |
+
- **Type**: Integer
|
| 500 |
+
- **Required**: No
|
| 501 |
+
- **Default**: `16`
|
| 502 |
+
- **Description**: Number of steps to accumulate gradients
|
| 503 |
+
- **Used in**: Lines ~477 in `run_cpt.py`
|
| 504 |
+
- **Implementation**:
|
| 505 |
+
```python
|
| 506 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1))
|
| 507 |
+
```
|
| 508 |
+
- **Example Values**: `1`, `4`, `8`, `16`, `32`
|
| 509 |
+
|
| 510 |
+
### `train.learning_rate`
|
| 511 |
+
- **Type**: Float
|
| 512 |
+
- **Required**: No
|
| 513 |
+
- **Default**: `2e-5`
|
| 514 |
+
- **Description**: Learning rate for optimizer
|
| 515 |
+
- **Used in**: Lines ~478 in `run_cpt.py`
|
| 516 |
+
- **Implementation**:
|
| 517 |
+
```python
|
| 518 |
+
learning_rate=float(tr_cfg.get("learning_rate", 2e-5))
|
| 519 |
+
```
|
| 520 |
+
- **Example Values**: `1e-5`, `2e-5`, `5e-5`, `1e-4`
|
| 521 |
+
|
| 522 |
+
### `train.weight_decay`
|
| 523 |
+
- **Type**: Float
|
| 524 |
+
- **Required**: No
|
| 525 |
+
- **Default**: `0.0`
|
| 526 |
+
- **Description**: Weight decay for regularization
|
| 527 |
+
- **Used in**: Lines ~479 in `run_cpt.py`
|
| 528 |
+
- **Implementation**:
|
| 529 |
+
```python
|
| 530 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0))
|
| 531 |
+
```
|
| 532 |
+
- **Example Values**: `0.0`, `0.01`, `0.1`
|
| 533 |
+
|
| 534 |
+
### `train.warmup_ratio`
|
| 535 |
+
- **Type**: Float
|
| 536 |
+
- **Required**: No
|
| 537 |
+
- **Default**: `0.1`
|
| 538 |
+
- **Description**: Ratio of steps for learning rate warmup
|
| 539 |
+
- **Used in**: Lines ~480 in `run_cpt.py`
|
| 540 |
+
- **Implementation**:
|
| 541 |
+
```python
|
| 542 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0))
|
| 543 |
+
```
|
| 544 |
+
- **Example Values**: `0.0`, `0.1`, `0.2`
|
| 545 |
+
|
| 546 |
+
### `train.lr_scheduler_type`
|
| 547 |
+
- **Type**: String
|
| 548 |
+
- **Required**: No
|
| 549 |
+
- **Default**: `"cosine"`
|
| 550 |
+
- **Description**: Learning rate scheduler type
|
| 551 |
+
- **Used in**: Lines ~481 in `run_cpt.py`
|
| 552 |
+
- **Implementation**:
|
| 553 |
+
```python
|
| 554 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine"))
|
| 555 |
+
```
|
| 556 |
+
- **Example Values**:
|
| 557 |
+
- `"cosine"` - Cosine annealing
|
| 558 |
+
- `"linear"` - Linear decay
|
| 559 |
+
- `"constant"` - Constant rate
|
| 560 |
+
- `"polynomial"` - Polynomial decay
|
| 561 |
+
|
| 562 |
+
### `train.optim`
|
| 563 |
+
- **Type**: String
|
| 564 |
+
- **Required**: No
|
| 565 |
+
- **Default**: `"paged_adamw_8bit"` (if 4-bit), `"adamw_torch"` (otherwise)
|
| 566 |
+
- **Description**: Optimizer type
|
| 567 |
+
- **Used in**: Lines ~482 in `run_cpt.py`
|
| 568 |
+
- **Implementation**:
|
| 569 |
+
```python
|
| 570 |
+
optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch"))
|
| 571 |
+
```
|
| 572 |
+
- **Example Values**:
|
| 573 |
+
- `"adamw_torch"` - AdamW (standard)
|
| 574 |
+
- `"paged_adamw_8bit"` - Paged AdamW for 8-bit training
|
| 575 |
+
- `"sgd"` - SGD
|
| 576 |
+
- `"adafactor"` - Adafactor
|
| 577 |
+
|
| 578 |
+
### `train.max_grad_norm`
|
| 579 |
+
- **Type**: Float
|
| 580 |
+
- **Required**: No
|
| 581 |
+
- **Default**: `1.0`
|
| 582 |
+
- **Description**: Maximum gradient norm for clipping
|
| 583 |
+
- **Used in**: Lines ~483 in `run_cpt.py`
|
| 584 |
+
- **Implementation**:
|
| 585 |
+
```python
|
| 586 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0))
|
| 587 |
+
```
|
| 588 |
+
- **Example Values**: `0.5`, `1.0`, `2.0`
|
| 589 |
+
|
| 590 |
+
### `train.gradient_checkpointing`
|
| 591 |
+
- **Type**: Boolean
|
| 592 |
+
- **Required**: No
|
| 593 |
+
- **Default**: `true`
|
| 594 |
+
- **Description**: Use gradient checkpointing to save memory
|
| 595 |
+
- **Used in**: Lines ~396-400 in `run_cpt.py`
|
| 596 |
+
- **Implementation**:
|
| 597 |
+
```python
|
| 598 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 599 |
+
if gradient_checkpointing:
|
| 600 |
+
model.gradient_checkpointing_enable()
|
| 601 |
+
```
|
| 602 |
+
- **Example Values**: `true`, `false`
|
| 603 |
+
|
| 604 |
+
### `train.logging_steps`
|
| 605 |
+
- **Type**: Integer
|
| 606 |
+
- **Required**: No
|
| 607 |
+
- **Default**: `1`
|
| 608 |
+
- **Description**: Log training progress every N steps
|
| 609 |
+
- **Used in**: Lines ~485 in `run_cpt.py`
|
| 610 |
+
- **Implementation**:
|
| 611 |
+
```python
|
| 612 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10))
|
| 613 |
+
```
|
| 614 |
+
- **Example Values**: `1`, `10`, `50`, `100`
|
| 615 |
+
|
| 616 |
+
### `train.save_strategy`
|
| 617 |
+
- **Type**: String
|
| 618 |
+
- **Required**: No
|
| 619 |
+
- **Default**: `"steps"`
|
| 620 |
+
- **Description**: When to save model checkpoints
|
| 621 |
+
- **Used in**: Lines ~487 in `run_cpt.py`
|
| 622 |
+
- **Implementation**:
|
| 623 |
+
```python
|
| 624 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps"))
|
| 625 |
+
```
|
| 626 |
+
- **Example Values**:
|
| 627 |
+
- `"steps"` - Save every N steps
|
| 628 |
+
- `"epochs"` - Save every epoch
|
| 629 |
+
- `"no"` - Don't save
|
| 630 |
+
|
| 631 |
+
### `train.save_steps`
|
| 632 |
+
- **Type**: Integer
|
| 633 |
+
- **Required**: No
|
| 634 |
+
- **Default**: `100`
|
| 635 |
+
- **Description**: Save checkpoint every N steps
|
| 636 |
+
- **Used in**: Lines ~488 in `run_cpt.py`
|
| 637 |
+
- **Implementation**:
|
| 638 |
+
```python
|
| 639 |
+
save_steps=int(tr_cfg.get("save_steps", 200))
|
| 640 |
+
```
|
| 641 |
+
- **Example Values**: `50`, `100`, `200`, `500`
|
| 642 |
+
|
| 643 |
+
### `train.save_total_limit`
|
| 644 |
+
- **Type**: Integer
|
| 645 |
+
- **Required**: No
|
| 646 |
+
- **Default**: `4`
|
| 647 |
+
- **Description**: Maximum number of checkpoints to keep
|
| 648 |
+
- **Used in**: Lines ~489 in `run_cpt.py`
|
| 649 |
+
- **Implementation**:
|
| 650 |
+
```python
|
| 651 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3))
|
| 652 |
+
```
|
| 653 |
+
- **Example Values**: `1`, `2`, `3`, `5`
|
| 654 |
+
|
| 655 |
+
### `train.evaluation_strategy`
|
| 656 |
+
- **Type**: String
|
| 657 |
+
- **Required**: No
|
| 658 |
+
- **Default**: `"steps"` (if eval data), `"no"` (otherwise)
|
| 659 |
+
- **Description**: When to evaluate model
|
| 660 |
+
- **Used in**: Lines ~494 in `run_cpt.py`
|
| 661 |
+
- **Implementation**:
|
| 662 |
+
```python
|
| 663 |
+
evaluation_strategy=str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
|
| 664 |
+
```
|
| 665 |
+
- **Example Values**:
|
| 666 |
+
- `"steps"` - Evaluate every N steps
|
| 667 |
+
- `"epochs"` - Evaluate every epoch
|
| 668 |
+
- `"no"` - Don't evaluate
|
| 669 |
+
|
| 670 |
+
### `train.eval_steps`
|
| 671 |
+
- **Type**: Integer
|
| 672 |
+
- **Required**: No
|
| 673 |
+
- **Default**: `50`
|
| 674 |
+
- **Description**: Evaluate every N steps
|
| 675 |
+
- **Used in**: Lines ~491 in `run_cpt.py`
|
| 676 |
+
- **Implementation**:
|
| 677 |
+
```python
|
| 678 |
+
eval_steps=int(tr_cfg.get("eval_steps", 200))
|
| 679 |
+
```
|
| 680 |
+
- **Example Values**: `25`, `50`, `100`, `200`
|
| 681 |
+
|
| 682 |
+
### `train.load_best_model_at_end`
|
| 683 |
+
- **Type**: Boolean
|
| 684 |
+
- **Required**: No
|
| 685 |
+
- **Default**: `true` (if eval data), `false` (otherwise)
|
| 686 |
+
- **Description**: Load best model at end of training
|
| 687 |
+
- **Used in**: Lines ~492-493 in `run_cpt.py`
|
| 688 |
+
- **Implementation**:
|
| 689 |
+
```python
|
| 690 |
+
load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False
|
| 691 |
+
```
|
| 692 |
+
- **Example Values**: `true`, `false`
|
| 693 |
+
|
| 694 |
+
### `train.resume_from_checkpoint`
|
| 695 |
+
- **Type**: String
|
| 696 |
+
- **Required**: No
|
| 697 |
+
- **Default**: `"auto"`
|
| 698 |
+
- **Description**: Resume training from checkpoint
|
| 699 |
+
- **Used in**: Lines ~510-520 in `run_cpt.py`
|
| 700 |
+
- **Implementation**:
|
| 701 |
+
```python
|
| 702 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 703 |
+
if resume_from == "auto":
|
| 704 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 705 |
+
resume_from = last if last else None
|
| 706 |
+
```
|
| 707 |
+
- **Example Values**:
|
| 708 |
+
- `"auto"` - Auto-detect latest checkpoint
|
| 709 |
+
- `"checkpoint-100"` - Specific checkpoint
|
| 710 |
+
- `null` - Start from scratch
|
| 711 |
+
|
| 712 |
+
---
|
| 713 |
+
|
| 714 |
+
## Merge Parameters
|
| 715 |
+
|
| 716 |
+
### `merge.enabled`
|
| 717 |
+
- **Type**: Boolean
|
| 718 |
+
- **Required**: No
|
| 719 |
+
- **Default**: `false`
|
| 720 |
+
- **Description**: Whether to merge LoRA adapters with base model
|
| 721 |
+
- **Used in**: Lines ~545 in `run_cpt.py`
|
| 722 |
+
- **Implementation**:
|
| 723 |
+
```python
|
| 724 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 725 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 726 |
+
```
|
| 727 |
+
- **Example Values**: `true`, `false`
|
| 728 |
+
|
| 729 |
+
### `merge.merged_dtype`
|
| 730 |
+
- **Type**: String
|
| 731 |
+
- **Required**: No
|
| 732 |
+
- **Default**: `"float16"`
|
| 733 |
+
- **Description**: Data type for merged model
|
| 734 |
+
- **Used in**: Lines ~430 in `run_cpt.py`
|
| 735 |
+
- **Implementation**:
|
| 736 |
+
```python
|
| 737 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 738 |
+
```
|
| 739 |
+
- **Example Values**: `"float16"`, `"bfloat16"`, `"float32"`
|
| 740 |
+
|
| 741 |
+
### `merge.max_shard_size`
|
| 742 |
+
- **Type**: String
|
| 743 |
+
- **Required**: No
|
| 744 |
+
- **Default**: `"2GB"`
|
| 745 |
+
- **Description**: Maximum size per shard when saving
|
| 746 |
+
- **Used in**: Lines ~445 in `run_cpt.py`
|
| 747 |
+
- **Implementation**:
|
| 748 |
+
```python
|
| 749 |
+
merged.save_pretrained(str(final_dir), safe_serialization=True, max_shard_size=max_shard_size)
|
| 750 |
+
```
|
| 751 |
+
- **Example Values**: `"1GB"`, `"2GB"`, `"5GB"`
|
| 752 |
+
|
| 753 |
+
### `merge.output_dir`
|
| 754 |
+
- **Type**: String (path)
|
| 755 |
+
- **Required**: No
|
| 756 |
+
- **Default**: `"./merged_model"`
|
| 757 |
+
- **Description**: Directory for merged model output
|
| 758 |
+
- **Used in**: Lines ~505-510 in `run_cpt.py`
|
| 759 |
+
- **Implementation**:
|
| 760 |
+
```python
|
| 761 |
+
if merge_cfg.get("output_dir"):
|
| 762 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 763 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 764 |
+
else:
|
| 765 |
+
final_dir = run_dir / "final_model"
|
| 766 |
+
```
|
| 767 |
+
- **Example Values**: `"./merged_model"`, `"/workspace/final_model"`, `"./models/merged"`
|
| 768 |
+
|
| 769 |
+
---
|
| 770 |
+
|
| 771 |
+
## Parameter Dependencies and Interactions
|
| 772 |
+
|
| 773 |
+
### Memory-Related Dependencies
|
| 774 |
+
- `per_device_train_batch_size` + `gradient_accumulation_steps` = effective batch size
|
| 775 |
+
- `block_size` affects memory usage significantly
|
| 776 |
+
- `use_4bit` + `bnb_4bit_*` parameters work together for quantization
|
| 777 |
+
- `gradient_checkpointing` can enable larger `block_size` or `batch_size`
|
| 778 |
+
|
| 779 |
+
### Training Strategy Dependencies
|
| 780 |
+
- `evaluation_strategy` requires either `eval_jsonl` or `eval_split_ratio > 0`
|
| 781 |
+
- `load_best_model_at_end` requires `evaluation_strategy` to be enabled
|
| 782 |
+
- `save_strategy` should be compatible with `evaluation_strategy`
|
| 783 |
+
- `lr_scheduler_type` affects warmup calculations
|
| 784 |
+
|
| 785 |
+
### Model-Specific Dependencies
|
| 786 |
+
- `target_modules` must match the actual module names in your model
|
| 787 |
+
- `torch_dtype` should be compatible with your GPU hardware
|
| 788 |
+
- `device_map` affects whether you can use certain optimizations
|
| 789 |
+
|
| 790 |
+
### Data Processing Dependencies
|
| 791 |
+
- `text_field` must exist in your JSONL data
|
| 792 |
+
- `pack_mode: "pad"` requires `block_size` to be set appropriately
|
| 793 |
+
- `eval_split_ratio` is ignored if `eval_jsonl` is provided
|
| 794 |
+
|
| 795 |
+
This comprehensive documentation should help you understand and configure all parameters in the CPT training system according to your specific needs and constraints.
|
trainer-kit/CPT/dummy_data.jsonl
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "This is a test sentence for the dummy dataset."}
|
| 2 |
+
{"text": "Another sentence to check if training works."}
|
| 3 |
+
{"text": "We need enough data to form a batch."}
|
| 4 |
+
{"text": "FSDP and LoRA are cool technologies."}
|
| 5 |
+
{"text": "Fine-tuning LLMs is fun and useful."}
|
| 6 |
+
{"text": "This is the end of the dummy dataset."}
|
trainer-kit/CPT/requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
transformers>=4.41.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
accelerate>=0.30.0
|
| 6 |
+
|
| 7 |
+
# PEFT / QLoRA
|
| 8 |
+
peft>=0.11.1
|
| 9 |
+
bitsandbytes>=0.43.1
|
| 10 |
+
|
| 11 |
+
# Hugging Face Hub (local + download support)
|
| 12 |
+
huggingface_hub>=0.23.0
|
| 13 |
+
|
| 14 |
+
# Config + utilities
|
| 15 |
+
pyyaml>=6.0
|
| 16 |
+
tqdm>=4.66.0
|
| 17 |
+
|
| 18 |
+
# Optional but recommended (tokenizers speed)
|
| 19 |
+
tokenizers>=0.15.0
|
| 20 |
+
safetensors>=0.4.2
|
| 21 |
+
# Optional (for eval)
|
| 22 |
+
rouge-score>=0.1.2
|
trainer-kit/CPT/run_cpt.py
ADDED
|
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect # Added for Transformers version compatibility
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
from datasets import load_dataset, DatasetDict
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from transformers import (
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
AutoTokenizer,
|
| 16 |
+
PreTrainedTokenizerFast,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
Trainer,
|
| 19 |
+
TrainerCallback,
|
| 20 |
+
default_data_collator,
|
| 21 |
+
set_seed,
|
| 22 |
+
)
|
| 23 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 24 |
+
from peft import (
|
| 25 |
+
LoraConfig,
|
| 26 |
+
get_peft_model,
|
| 27 |
+
prepare_model_for_kbit_training,
|
| 28 |
+
PeftModel,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from transformers import BitsAndBytesConfig
|
| 33 |
+
except ImportError: # older transformers
|
| 34 |
+
BitsAndBytesConfig = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# --------------------------
|
| 38 |
+
# Helpers
|
| 39 |
+
# --------------------------
|
| 40 |
+
|
| 41 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 42 |
+
s = (s or "").lower()
|
| 43 |
+
if s in ("float16", "fp16"):
|
| 44 |
+
return torch.float16
|
| 45 |
+
if s in ("bfloat16", "bf16"):
|
| 46 |
+
return torch.bfloat16
|
| 47 |
+
if s in ("float32", "fp32"):
|
| 48 |
+
return torch.float32
|
| 49 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 50 |
+
|
| 51 |
+
def _now_iso() -> str:
|
| 52 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 53 |
+
|
| 54 |
+
def _safe_exp(x: float) -> float:
|
| 55 |
+
x = min(float(x), 50.0)
|
| 56 |
+
return float(math.exp(x))
|
| 57 |
+
|
| 58 |
+
def _ensure_dir(p: Path) -> Path:
|
| 59 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 60 |
+
return p
|
| 61 |
+
|
| 62 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 63 |
+
if not p.exists() or not p.is_dir():
|
| 64 |
+
return False
|
| 65 |
+
if (p / "config.json").exists():
|
| 66 |
+
return True
|
| 67 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 68 |
+
return True
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
def _detect_text_field(example: Dict[str, Any]) -> Optional[str]:
|
| 72 |
+
for k, v in example.items():
|
| 73 |
+
if isinstance(v, str) and v.strip():
|
| 74 |
+
return k
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
def _load_tokenizer(base_dir: Path, use_fast: bool, trust_remote_code: bool):
|
| 78 |
+
try:
|
| 79 |
+
return AutoTokenizer.from_pretrained(
|
| 80 |
+
str(base_dir),
|
| 81 |
+
use_fast=use_fast,
|
| 82 |
+
trust_remote_code=trust_remote_code,
|
| 83 |
+
)
|
| 84 |
+
except ValueError as e:
|
| 85 |
+
if "TokenizersBackend" not in str(e):
|
| 86 |
+
raise
|
| 87 |
+
tok_file = base_dir / "tokenizer.json"
|
| 88 |
+
tok_cfg_path = base_dir / "tokenizer_config.json"
|
| 89 |
+
if not tok_file.exists():
|
| 90 |
+
raise
|
| 91 |
+
|
| 92 |
+
tok_kwargs: Dict[str, Any] = {}
|
| 93 |
+
if tok_cfg_path.exists():
|
| 94 |
+
with tok_cfg_path.open("r", encoding="utf-8") as f:
|
| 95 |
+
tok_cfg = json.load(f)
|
| 96 |
+
for key in ("bos_token", "eos_token", "pad_token", "unk_token", "model_max_length"):
|
| 97 |
+
if tok_cfg.get(key) is not None:
|
| 98 |
+
tok_kwargs[key] = tok_cfg[key]
|
| 99 |
+
extra = tok_cfg.get("additional_special_tokens") or tok_cfg.get("extra_special_tokens")
|
| 100 |
+
if extra:
|
| 101 |
+
tok_kwargs["additional_special_tokens"] = extra
|
| 102 |
+
|
| 103 |
+
return PreTrainedTokenizerFast(tokenizer_file=str(tok_file), **tok_kwargs)
|
| 104 |
+
|
| 105 |
+
def _infer_target_modules(model) -> List[str]:
|
| 106 |
+
names = set()
|
| 107 |
+
for n, _ in model.named_modules():
|
| 108 |
+
names.add(n.split(".")[-1])
|
| 109 |
+
|
| 110 |
+
for group in [
|
| 111 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 112 |
+
["Wqkv", "out_proj"],
|
| 113 |
+
["query_key_value", "dense"],
|
| 114 |
+
["c_attn", "c_proj"],
|
| 115 |
+
]:
|
| 116 |
+
if all(x in names for x in group):
|
| 117 |
+
return group
|
| 118 |
+
|
| 119 |
+
fallback = [x for x in ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense"] if x in names]
|
| 120 |
+
if fallback:
|
| 121 |
+
return fallback
|
| 122 |
+
|
| 123 |
+
raise ValueError("Could not auto-infer target_modules. Set peft.target_modules explicitly.")
|
| 124 |
+
|
| 125 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 126 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# --------------------------
|
| 130 |
+
# JSONL Logger Callback
|
| 131 |
+
# --------------------------
|
| 132 |
+
|
| 133 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 134 |
+
def __init__(self, run_dir: Path):
|
| 135 |
+
self.run_dir = run_dir
|
| 136 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 137 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 138 |
+
self.start_time = None
|
| 139 |
+
|
| 140 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 141 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 142 |
+
return None
|
| 143 |
+
elapsed = time.time() - self.start_time
|
| 144 |
+
sec_per_step = elapsed / global_step
|
| 145 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 146 |
+
h = int(remaining // 3600)
|
| 147 |
+
m = int((remaining % 3600) // 60)
|
| 148 |
+
s = int(remaining % 60)
|
| 149 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 150 |
+
|
| 151 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 152 |
+
self.start_time = time.time()
|
| 153 |
+
|
| 154 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 155 |
+
if not logs:
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 159 |
+
progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 160 |
+
epoch_pct = None
|
| 161 |
+
if state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0:
|
| 162 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 163 |
+
|
| 164 |
+
payload = {
|
| 165 |
+
"ts": _now_iso(),
|
| 166 |
+
"event": "train_log",
|
| 167 |
+
"step": int(state.global_step),
|
| 168 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 169 |
+
"progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
|
| 170 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 171 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 172 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 173 |
+
**logs,
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 177 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 178 |
+
|
| 179 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 180 |
+
if not metrics:
|
| 181 |
+
return
|
| 182 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 183 |
+
ppl = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 184 |
+
|
| 185 |
+
payload = {
|
| 186 |
+
"ts": _now_iso(),
|
| 187 |
+
"event": "eval",
|
| 188 |
+
"step": int(state.global_step),
|
| 189 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 190 |
+
**metrics,
|
| 191 |
+
"perplexity": ppl,
|
| 192 |
+
}
|
| 193 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 194 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --------------------------
|
| 198 |
+
# Data Pipeline (EOS + Packing)
|
| 199 |
+
# --------------------------
|
| 200 |
+
|
| 201 |
+
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 202 |
+
data_cfg = cfg["data"]
|
| 203 |
+
train_path = data_cfg["train_jsonl"]
|
| 204 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 205 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 206 |
+
text_field = data_cfg.get("text_field", "text")
|
| 207 |
+
block_size = int(data_cfg.get("block_size", 2048))
|
| 208 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 209 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 210 |
+
|
| 211 |
+
pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
|
| 212 |
+
if pack_mode not in ("drop", "pad"):
|
| 213 |
+
raise ValueError(f"data.pack_mode must be 'drop' or 'pad', got: {pack_mode}")
|
| 214 |
+
|
| 215 |
+
eos_id = tokenizer.eos_token_id
|
| 216 |
+
if eos_id is None:
|
| 217 |
+
raise ValueError("Tokenizer has no eos_token_id; CPT packing needs an EOS delimiter.")
|
| 218 |
+
|
| 219 |
+
if tokenizer.pad_token_id is None:
|
| 220 |
+
# safe default for many causal LMs
|
| 221 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 222 |
+
pad_id = tokenizer.pad_token_id
|
| 223 |
+
|
| 224 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 225 |
+
|
| 226 |
+
if eval_path:
|
| 227 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 228 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 229 |
+
else:
|
| 230 |
+
if 0.0 < split_ratio < 1.0:
|
| 231 |
+
split = ds["train"].train_test_split(test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)))
|
| 232 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 233 |
+
else:
|
| 234 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 235 |
+
|
| 236 |
+
if text_field not in dsd["train"].column_names:
|
| 237 |
+
auto_field = _detect_text_field(dsd["train"][0])
|
| 238 |
+
if not auto_field:
|
| 239 |
+
raise ValueError(f"Could not find text field. Columns: {dsd['train'].column_names}")
|
| 240 |
+
text_field = auto_field
|
| 241 |
+
|
| 242 |
+
def tokenize_fn(examples):
|
| 243 |
+
out = tokenizer(
|
| 244 |
+
examples[text_field],
|
| 245 |
+
add_special_tokens=False,
|
| 246 |
+
truncation=False,
|
| 247 |
+
padding=False,
|
| 248 |
+
)
|
| 249 |
+
if "token_type_ids" in out:
|
| 250 |
+
del out["token_type_ids"]
|
| 251 |
+
# Add EOS between docs
|
| 252 |
+
out["input_ids"] = [ids + [eos_id] for ids in out["input_ids"]]
|
| 253 |
+
out["attention_mask"] = [m + [1] for m in out["attention_mask"]]
|
| 254 |
+
return out
|
| 255 |
+
|
| 256 |
+
tokenized_train = dsd["train"].map(
|
| 257 |
+
tokenize_fn,
|
| 258 |
+
batched=True,
|
| 259 |
+
num_proc=num_proc,
|
| 260 |
+
remove_columns=dsd["train"].column_names,
|
| 261 |
+
desc="Tokenizing train",
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
tokenized_eval = None
|
| 265 |
+
if dsd["eval"] is not None:
|
| 266 |
+
tokenized_eval = dsd["eval"].map(
|
| 267 |
+
tokenize_fn,
|
| 268 |
+
batched=True,
|
| 269 |
+
num_proc=num_proc,
|
| 270 |
+
remove_columns=dsd["eval"].column_names,
|
| 271 |
+
desc="Tokenizing eval",
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
def group_texts(examples):
|
| 275 |
+
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
| 276 |
+
total_length = len(concatenated["input_ids"])
|
| 277 |
+
|
| 278 |
+
if total_length == 0:
|
| 279 |
+
return {"input_ids": [], "attention_mask": [], "labels": []}
|
| 280 |
+
|
| 281 |
+
full_len = (total_length // block_size) * block_size
|
| 282 |
+
blocks_input, blocks_attn, blocks_labels = [], [], []
|
| 283 |
+
|
| 284 |
+
# full blocks
|
| 285 |
+
for i in range(0, full_len, block_size):
|
| 286 |
+
chunk = concatenated["input_ids"][i:i + block_size]
|
| 287 |
+
attn = concatenated["attention_mask"][i:i + block_size]
|
| 288 |
+
blocks_input.append(chunk)
|
| 289 |
+
blocks_attn.append(attn)
|
| 290 |
+
blocks_labels.append(chunk.copy())
|
| 291 |
+
|
| 292 |
+
# remainder
|
| 293 |
+
remainder = total_length - full_len
|
| 294 |
+
if remainder > 0 and pack_mode == "pad":
|
| 295 |
+
chunk = concatenated["input_ids"][full_len:full_len + remainder]
|
| 296 |
+
attn = concatenated["attention_mask"][full_len:full_len + remainder]
|
| 297 |
+
|
| 298 |
+
pad_len = block_size - remainder
|
| 299 |
+
chunk_padded = chunk + [pad_id] * pad_len
|
| 300 |
+
attn_padded = attn + [0] * pad_len
|
| 301 |
+
|
| 302 |
+
labels = chunk_padded.copy()
|
| 303 |
+
labels[-pad_len:] = [-100] * pad_len # loss mask
|
| 304 |
+
|
| 305 |
+
blocks_input.append(chunk_padded)
|
| 306 |
+
blocks_attn.append(attn_padded)
|
| 307 |
+
blocks_labels.append(labels)
|
| 308 |
+
|
| 309 |
+
return {
|
| 310 |
+
"input_ids": blocks_input,
|
| 311 |
+
"attention_mask": blocks_attn,
|
| 312 |
+
"labels": blocks_labels,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
tokenized_train = tokenized_train.map(
|
| 316 |
+
group_texts,
|
| 317 |
+
batched=True,
|
| 318 |
+
num_proc=num_proc,
|
| 319 |
+
desc=f"Packing train blocks (mode={pack_mode})",
|
| 320 |
+
)
|
| 321 |
+
if tokenized_eval is not None:
|
| 322 |
+
tokenized_eval = tokenized_eval.map(
|
| 323 |
+
group_texts,
|
| 324 |
+
batched=True,
|
| 325 |
+
num_proc=num_proc,
|
| 326 |
+
desc=f"Packing eval blocks (mode={pack_mode})",
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
if len(tokenized_train) == 0:
|
| 330 |
+
raise ValueError(
|
| 331 |
+
"Train dataset is empty after packing. "
|
| 332 |
+
"Either increase data, reduce block_size, or set data.pack_mode='pad'."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if shuffle:
|
| 336 |
+
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 337 |
+
|
| 338 |
+
return tokenized_train, tokenized_eval
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# --------------------------
|
| 342 |
+
# Model Loading + PEFT
|
| 343 |
+
# --------------------------
|
| 344 |
+
|
| 345 |
+
def _select_model_loader(base_dir: Path):
|
| 346 |
+
cfg_path = base_dir / "config.json"
|
| 347 |
+
if not cfg_path.exists():
|
| 348 |
+
return {"kind": "causal", "arch": None}
|
| 349 |
+
with cfg_path.open("r", encoding="utf-8") as f:
|
| 350 |
+
cfg = json.load(f)
|
| 351 |
+
arch = cfg.get("architectures") or []
|
| 352 |
+
arch_name = arch[0] if arch else None
|
| 353 |
+
if any("ForConditionalGeneration" in a for a in arch):
|
| 354 |
+
return {"kind": "conditional", "arch": arch_name}
|
| 355 |
+
return {"kind": "causal", "arch": arch_name}
|
| 356 |
+
|
| 357 |
+
def _resolve_model_class(arch_name: str):
|
| 358 |
+
import transformers
|
| 359 |
+
cls = getattr(transformers, arch_name, None)
|
| 360 |
+
if cls is None:
|
| 361 |
+
raise ValueError(f"Model class '{arch_name}' is not available in installed transformers.")
|
| 362 |
+
return cls
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 366 |
+
model_cfg = cfg["model"]
|
| 367 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 368 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 369 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 370 |
+
|
| 371 |
+
tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 372 |
+
if tokenizer.pad_token is None:
|
| 373 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 374 |
+
|
| 375 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 376 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 377 |
+
|
| 378 |
+
quant_cfg = None
|
| 379 |
+
if use_4bit:
|
| 380 |
+
if BitsAndBytesConfig is None:
|
| 381 |
+
raise ImportError("BitsAndBytesConfig is not available in this transformers version.")
|
| 382 |
+
quant_cfg = BitsAndBytesConfig(
|
| 383 |
+
load_in_4bit=True,
|
| 384 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 385 |
+
bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
|
| 386 |
+
bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")),
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 390 |
+
model_meta = _select_model_loader(base_dir)
|
| 391 |
+
|
| 392 |
+
try:
|
| 393 |
+
if model_meta["kind"] == "conditional":
|
| 394 |
+
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 395 |
+
if model_cls is None:
|
| 396 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 397 |
+
model = model_cls.from_pretrained(
|
| 398 |
+
str(base_dir),
|
| 399 |
+
device_map=device_map,
|
| 400 |
+
trust_remote_code=trust_remote_code,
|
| 401 |
+
low_cpu_mem_usage=True,
|
| 402 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 403 |
+
quantization_config=quant_cfg,
|
| 404 |
+
attn_implementation=attn_impl,
|
| 405 |
+
)
|
| 406 |
+
else:
|
| 407 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 408 |
+
str(base_dir),
|
| 409 |
+
device_map=device_map,
|
| 410 |
+
trust_remote_code=trust_remote_code,
|
| 411 |
+
low_cpu_mem_usage=True,
|
| 412 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 413 |
+
quantization_config=quant_cfg,
|
| 414 |
+
attn_implementation=attn_impl,
|
| 415 |
+
)
|
| 416 |
+
except Exception as e:
|
| 417 |
+
if attn_impl is not None:
|
| 418 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 419 |
+
print("[warn] Falling back to default attention implementation.")
|
| 420 |
+
if model_meta["kind"] == "conditional":
|
| 421 |
+
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 422 |
+
if model_cls is None:
|
| 423 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 424 |
+
model = model_cls.from_pretrained(
|
| 425 |
+
str(base_dir),
|
| 426 |
+
device_map=device_map,
|
| 427 |
+
trust_remote_code=trust_remote_code,
|
| 428 |
+
low_cpu_mem_usage=True,
|
| 429 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 430 |
+
quantization_config=quant_cfg,
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 434 |
+
str(base_dir),
|
| 435 |
+
device_map=device_map,
|
| 436 |
+
trust_remote_code=trust_remote_code,
|
| 437 |
+
low_cpu_mem_usage=True,
|
| 438 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 439 |
+
quantization_config=quant_cfg,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
return model, tokenizer
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 446 |
+
peft_cfg = cfg["peft"]
|
| 447 |
+
model_cfg = cfg["model"]
|
| 448 |
+
tr_cfg = cfg["train"]
|
| 449 |
+
|
| 450 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 451 |
+
return model, None
|
| 452 |
+
|
| 453 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 454 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 455 |
+
|
| 456 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 457 |
+
model.gradient_checkpointing_enable()
|
| 458 |
+
if hasattr(model, "config"):
|
| 459 |
+
model.config.use_cache = False
|
| 460 |
+
|
| 461 |
+
if use_4bit:
|
| 462 |
+
model = prepare_model_for_kbit_training(
|
| 463 |
+
model,
|
| 464 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 468 |
+
if target_modules == "auto":
|
| 469 |
+
target_modules = _infer_target_modules(model)
|
| 470 |
+
|
| 471 |
+
lora_config = LoraConfig(
|
| 472 |
+
r=int(peft_cfg.get("r", 16)),
|
| 473 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 474 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 475 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 476 |
+
task_type="CAUSAL_LM",
|
| 477 |
+
target_modules=target_modules,
|
| 478 |
+
)
|
| 479 |
+
model = get_peft_model(model, lora_config)
|
| 480 |
+
return model, lora_config
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# --------------------------
|
| 484 |
+
# Merge Logic
|
| 485 |
+
# --------------------------
|
| 486 |
+
|
| 487 |
+
def merge_adapter(cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path):
|
| 488 |
+
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 489 |
+
|
| 490 |
+
model_cfg = cfg["model"]
|
| 491 |
+
merge_cfg = cfg.get("merge", {})
|
| 492 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 493 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 494 |
+
|
| 495 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 496 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 497 |
+
|
| 498 |
+
model_meta = _select_model_loader(base_dir)
|
| 499 |
+
if model_meta["kind"] == "conditional":
|
| 500 |
+
base_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
|
| 501 |
+
if base_cls is None:
|
| 502 |
+
raise ValueError("Conditional model architecture not specified in config.json.")
|
| 503 |
+
base = base_cls.from_pretrained(
|
| 504 |
+
str(base_dir),
|
| 505 |
+
torch_dtype=merged_dtype,
|
| 506 |
+
device_map="cpu",
|
| 507 |
+
low_cpu_mem_usage=True,
|
| 508 |
+
trust_remote_code=trust_remote_code,
|
| 509 |
+
)
|
| 510 |
+
else:
|
| 511 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 512 |
+
str(base_dir),
|
| 513 |
+
torch_dtype=merged_dtype,
|
| 514 |
+
device_map="cpu",
|
| 515 |
+
low_cpu_mem_usage=True,
|
| 516 |
+
trust_remote_code=trust_remote_code,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 520 |
+
merged = merged.merge_and_unload()
|
| 521 |
+
|
| 522 |
+
_ensure_dir(final_dir)
|
| 523 |
+
# Fix for transformers weight conversion bug with quantized models
|
| 524 |
+
# Clear weight conversions to avoid NotImplementedError in reverse_transform
|
| 525 |
+
if hasattr(merged, '_weight_conversions'):
|
| 526 |
+
merged._weight_conversions = []
|
| 527 |
+
merged.save_pretrained(
|
| 528 |
+
str(final_dir),
|
| 529 |
+
safe_serialization=True,
|
| 530 |
+
max_shard_size=max_shard_size
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
tok = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
|
| 534 |
+
if tok.pad_token is None:
|
| 535 |
+
tok.pad_token = tok.eos_token
|
| 536 |
+
tok.save_pretrained(str(final_dir))
|
| 537 |
+
|
| 538 |
+
print("--- Merge complete ---")
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
# --------------------------
|
| 542 |
+
# Main
|
| 543 |
+
# --------------------------
|
| 544 |
+
|
| 545 |
+
def main():
|
| 546 |
+
ap = argparse.ArgumentParser()
|
| 547 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 548 |
+
ap.add_argument("--merge-only", action="store_true", help="Skip training, just merge adapter")
|
| 549 |
+
args = ap.parse_args()
|
| 550 |
+
|
| 551 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 552 |
+
cfg = yaml.safe_load(f)
|
| 553 |
+
|
| 554 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 555 |
+
_ensure_dir(run_dir / "logs")
|
| 556 |
+
|
| 557 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 558 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 559 |
+
|
| 560 |
+
model_cfg = cfg["model"]
|
| 561 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 562 |
+
repo_path = Path(repo_id)
|
| 563 |
+
|
| 564 |
+
# ✅ Local model path -> load directly; no download
|
| 565 |
+
if repo_path.exists() and repo_path.is_dir():
|
| 566 |
+
base_dir = repo_path
|
| 567 |
+
if not _looks_like_model_dir(base_dir):
|
| 568 |
+
raise ValueError(f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}")
|
| 569 |
+
else:
|
| 570 |
+
# HF repo_id -> download into run_dir/base_local_dir
|
| 571 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 572 |
+
if not _looks_like_model_dir(base_dir):
|
| 573 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 574 |
+
snapshot_download(
|
| 575 |
+
repo_id=repo_id,
|
| 576 |
+
revision=model_cfg.get("revision", None),
|
| 577 |
+
local_dir=str(base_dir),
|
| 578 |
+
local_dir_use_symlinks=False,
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 582 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 583 |
+
|
| 584 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 585 |
+
if merge_cfg.get("output_dir"):
|
| 586 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 587 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 588 |
+
else:
|
| 589 |
+
final_dir = run_dir / "final_model"
|
| 590 |
+
|
| 591 |
+
# Merge-only
|
| 592 |
+
if args.merge_only:
|
| 593 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 594 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 595 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 596 |
+
return
|
| 597 |
+
|
| 598 |
+
# Training
|
| 599 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 600 |
+
|
| 601 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 602 |
+
model, _ = apply_peft(cfg, model)
|
| 603 |
+
|
| 604 |
+
train_ds, eval_ds = build_datasets(cfg, tokenizer)
|
| 605 |
+
|
| 606 |
+
tr_cfg = cfg["train"]
|
| 607 |
+
|
| 608 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 609 |
+
use_fp16 = (dtype == torch.float16)
|
| 610 |
+
use_bf16 = (dtype == torch.bfloat16)
|
| 611 |
+
|
| 612 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 613 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 614 |
+
|
| 615 |
+
# --- Dynamic evaluation strategy parameter handling ---
|
| 616 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 617 |
+
eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 618 |
+
|
| 619 |
+
desired_ta_kwargs = dict(
|
| 620 |
+
output_dir=str(ckpt_dir),
|
| 621 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 622 |
+
num_train_epochs=num_train_epochs,
|
| 623 |
+
|
| 624 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 625 |
+
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1))),
|
| 626 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 627 |
+
|
| 628 |
+
learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
|
| 629 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 630 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 631 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 632 |
+
|
| 633 |
+
optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch")),
|
| 634 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 635 |
+
|
| 636 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 637 |
+
|
| 638 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 639 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 640 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 641 |
+
|
| 642 |
+
eval_steps=int(tr_cfg.get("eval_steps", 200)),
|
| 643 |
+
|
| 644 |
+
load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False,
|
| 645 |
+
metric_for_best_model="eval_loss",
|
| 646 |
+
greater_is_better=False,
|
| 647 |
+
|
| 648 |
+
fp16=use_fp16,
|
| 649 |
+
bf16=use_bf16,
|
| 650 |
+
|
| 651 |
+
report_to=[],
|
| 652 |
+
remove_unused_columns=False,
|
| 653 |
+
save_safetensors=True,
|
| 654 |
+
overwrite_output_dir=False,
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
# Set the correct argument name for this transformers version
|
| 658 |
+
desired_ta_kwargs[eval_key] = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
|
| 659 |
+
ta_kwargs = {k: v for k, v in desired_ta_kwargs.items() if k in ta_params}
|
| 660 |
+
|
| 661 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 662 |
+
|
| 663 |
+
trainer_params = inspect.signature(Trainer.__init__).parameters
|
| 664 |
+
desired_trainer_kwargs = dict(
|
| 665 |
+
model=model,
|
| 666 |
+
args=training_args,
|
| 667 |
+
train_dataset=train_ds,
|
| 668 |
+
eval_dataset=eval_ds,
|
| 669 |
+
tokenizer=tokenizer,
|
| 670 |
+
processing_class=tokenizer,
|
| 671 |
+
data_collator=default_data_collator,
|
| 672 |
+
callbacks=[JsonlLoggerCallback(run_dir)],
|
| 673 |
+
)
|
| 674 |
+
trainer_kwargs = {k: v for k, v in desired_trainer_kwargs.items() if k in trainer_params}
|
| 675 |
+
trainer = Trainer(**trainer_kwargs)
|
| 676 |
+
|
| 677 |
+
# Resume
|
| 678 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 679 |
+
if resume_from == "auto":
|
| 680 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 681 |
+
resume_from = last if last else None
|
| 682 |
+
if resume_from:
|
| 683 |
+
print(f"Resuming from {resume_from}")
|
| 684 |
+
|
| 685 |
+
print("Starting training...")
|
| 686 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 687 |
+
|
| 688 |
+
trainer.save_model(str(best_adapter_dir))
|
| 689 |
+
print(f"Saved best adapter -> {best_adapter_dir}")
|
| 690 |
+
|
| 691 |
+
if eval_ds is not None:
|
| 692 |
+
metrics = trainer.evaluate()
|
| 693 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 694 |
+
metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 695 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 696 |
+
json.dump(metrics, f, indent=2)
|
| 697 |
+
print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
|
| 698 |
+
|
| 699 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 700 |
+
del trainer, model
|
| 701 |
+
torch.cuda.empty_cache()
|
| 702 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 703 |
+
else:
|
| 704 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
if __name__ == "__main__":
|
| 708 |
+
main()
|
trainer-kit/SFT-14b/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
trainer-kit/SFT-14b/config_instruct.yaml
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
run_dir: "./runs/instruct_run_14b_v1"
|
| 3 |
+
seed: 42
|
| 4 |
+
|
| 5 |
+
# WandB integration for experiment tracking
|
| 6 |
+
wandb:
|
| 7 |
+
enabled: true # Set to true to enable wandb logging
|
| 8 |
+
project: "sft-training" # WandB project name
|
| 9 |
+
entity: null # WandB entity/team (optional)
|
| 10 |
+
name: null # Run name (optional, will auto-generate if null)
|
| 11 |
+
tags: ["sft-lora", "instruction-tuning"] # List of tags for the run (e.g., ["lora", "qlora", "experiment-1"])
|
| 12 |
+
notes: null # Run description/notes (optional)
|
| 13 |
+
|
| 14 |
+
model:
|
| 15 |
+
# Use local Qwen2.5-Coder-14B model
|
| 16 |
+
repo_id: "./runs/cpt_run_14b/merged_14b_cpt_lora"
|
| 17 |
+
revision: null
|
| 18 |
+
|
| 19 |
+
# Used only when repo_id is a HF repo (not a local path)
|
| 20 |
+
base_local_dir: "base_model"
|
| 21 |
+
|
| 22 |
+
trust_remote_code: true
|
| 23 |
+
tokenizer_use_fast: true
|
| 24 |
+
device_map: "auto"
|
| 25 |
+
|
| 26 |
+
torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
|
| 27 |
+
|
| 28 |
+
# QLoRA
|
| 29 |
+
use_4bit: false
|
| 30 |
+
bnb_4bit_quant_type: "nf4"
|
| 31 |
+
bnb_4bit_use_double_quant: false
|
| 32 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 33 |
+
|
| 34 |
+
# optional: "flash_attention_2" | "sdpa" | null
|
| 35 |
+
attn_implementation: null
|
| 36 |
+
|
| 37 |
+
data:
|
| 38 |
+
train_jsonl: "sft_dataset.jsonl"
|
| 39 |
+
eval_jsonl: null
|
| 40 |
+
eval_split_ratio: 0.1
|
| 41 |
+
|
| 42 |
+
# Field names in your JSONL data
|
| 43 |
+
instruction_field: "instruction" # This will be the system prompt
|
| 44 |
+
input_field: "input" # This is the task description
|
| 45 |
+
output_field: "output" # This is the analysis + selection
|
| 46 |
+
|
| 47 |
+
# Formatting options
|
| 48 |
+
format_type: "custom" # "chatml" | "alpaca" | "custom"
|
| 49 |
+
|
| 50 |
+
# For chatml format
|
| 51 |
+
system_prompt: |
|
| 52 |
+
You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
|
| 53 |
+
|
| 54 |
+
## Output Format
|
| 55 |
+
|
| 56 |
+
##OUTPUT
|
| 57 |
+
Explain the data flow and why each component must change:
|
| 58 |
+
- Flow: [Input → Processing → Output with arrows]
|
| 59 |
+
- For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
|
| 60 |
+
- Explain coupling between components
|
| 61 |
+
|
| 62 |
+
##SELECT
|
| 63 |
+
modify::crates/path/to/file.rs::impl::ComponentName
|
| 64 |
+
add::crates/another/file.rs::function::AnotherComponent
|
| 65 |
+
<EOS>
|
| 66 |
+
|
| 67 |
+
## Rules
|
| 68 |
+
|
| 69 |
+
1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
|
| 70 |
+
2. Use `::` for nested items: `status::StructName::Type::Name`
|
| 71 |
+
3. Always explain "must change because" and "without this"
|
| 72 |
+
3. Types of components: function, struct, enum, impl, trait
|
| 73 |
+
4. If there is extra information (e.g., enum variants), include that too.
|
| 74 |
+
5. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
|
| 75 |
+
|
| 76 |
+
## Example
|
| 77 |
+
|
| 78 |
+
##TASK
|
| 79 |
+
Add webhook subscription support
|
| 80 |
+
|
| 81 |
+
##OUTPUT
|
| 82 |
+
The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
|
| 83 |
+
|
| 84 |
+
##SELECT
|
| 85 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 86 |
+
crates/common_enums/src/transformers.rs::SubscriptionStatus
|
| 87 |
+
<EOS>
|
| 88 |
+
|
| 89 |
+
# For custom format (only used when format_type="custom")
|
| 90 |
+
custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n##TASK\n{input}<|im_end|>\n##OUTPUT\n{output}<|im_end|>"
|
| 91 |
+
|
| 92 |
+
max_length: 2048
|
| 93 |
+
shuffle: true
|
| 94 |
+
num_proc: 4
|
| 95 |
+
|
| 96 |
+
peft:
|
| 97 |
+
enabled: true
|
| 98 |
+
r: 16
|
| 99 |
+
lora_alpha: 32
|
| 100 |
+
lora_dropout: 0.05
|
| 101 |
+
bias: "none"
|
| 102 |
+
target_modules: "auto"
|
| 103 |
+
|
| 104 |
+
train:
|
| 105 |
+
# max_steps: 10
|
| 106 |
+
num_train_epochs: 6
|
| 107 |
+
|
| 108 |
+
per_device_train_batch_size: 1
|
| 109 |
+
per_device_eval_batch_size: 1
|
| 110 |
+
gradient_accumulation_steps: 8
|
| 111 |
+
|
| 112 |
+
learning_rate: 2e-4
|
| 113 |
+
weight_decay: 0.0
|
| 114 |
+
warmup_ratio: 0.08
|
| 115 |
+
lr_scheduler_type: "cosine"
|
| 116 |
+
|
| 117 |
+
optim: "adamw_torch" # ✅ Changed from paged_adamw_8bit (requires use_4bit=true)
|
| 118 |
+
max_grad_norm: 1.0
|
| 119 |
+
gradient_checkpointing: true
|
| 120 |
+
|
| 121 |
+
logging_steps: 2
|
| 122 |
+
save_strategy: "steps"
|
| 123 |
+
save_steps: 500
|
| 124 |
+
save_total_limit: 20
|
| 125 |
+
|
| 126 |
+
evaluation_strategy: "steps"
|
| 127 |
+
eval_steps: 100
|
| 128 |
+
load_best_model_at_end: true
|
| 129 |
+
|
| 130 |
+
# Early stopping
|
| 131 |
+
early_stopping:
|
| 132 |
+
enabled: true
|
| 133 |
+
patience: 3 # Number of evaluations with no improvement before stopping
|
| 134 |
+
min_delta: 0.001 # Minimum change to qualify as improvement
|
| 135 |
+
metric: "eval_loss" # Metric to monitor
|
| 136 |
+
mode: "min" # "min" for loss, "max" for accuracy/etc.
|
| 137 |
+
|
| 138 |
+
resume_from_checkpoint: "auto"
|
| 139 |
+
|
| 140 |
+
merge:
|
| 141 |
+
enabled: true
|
| 142 |
+
merged_dtype: "float16"
|
| 143 |
+
max_shard_size: "2GB"
|
| 144 |
+
output_dir: "./merged_14b_instruct_lora"
|
trainer-kit/SFT-14b/instruct_data.jsonl
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"input": "##TASK\nAdd subscription support to webhook event metadata and outgoing webhook content systems to enable proper subscription lifecycle event handling and webhook delivery for subscription-related operations.", "output": "##OUTPUT\nThe webhook system requires subscription event support to handle subscription lifecycle events (creation, payment, cancellation) and deliver them to merchants via outgoing webhooks. This affects the event metadata enumeration and webhook content transformation layers. The complete flow is: subscription event occurs → EventMetadata::Subscription variant created with subscription_id, invoice_id, payment_id → stored in database → retrieved for webhook delivery → OutgoingWebhookEventContent::foreign_from converts EventMetadata to OutgoingWebhookEventContent::Subscription → StripeWebhookObject::From transforms to StripeWebhookObject::Subscriptions for Stripe compatibility. The OutgoingWebhookEventContent::foreign_from function (router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from) must add the Subscription variant match arm because it's the bridge from storage EventMetadata to webhook content representation—without this, subscription events cannot be converted to webhook content and would be silently dropped during webhook processing. The StripeWebhookObject::From implementation (router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From) must handle OutgoingWebhookContent::SubscriptionDetails because it provides Stripe-compatible webhook formatting—without this, merchants using Stripe webhook endpoints would receive malformed or missing subscription events. The EventMetadata enum (implicitly referenced in foreign_from) requires the Subscription variant to exist in the diesel_models layer, which enables persistent storage of subscription event references. These components form a tightly coupled webhook pipeline where each depends on the previous: EventMetadata → OutgoingWebhookEventContent → StripeWebhookObject. If foreign_from lacks the Subscription match arm, subscription events stored as EventMetadata::Subscription would cause panic during webhook delivery. If StripeWebhookObject::From lacks SubscriptionDetails handling, Stripe webhook endpoints would receive unparseable responses, breaking merchant integrations that rely on Stripe-compatible subscription webhooks for revenue recognition and customer management systems.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From\n<EOS>"}
|
| 2 |
+
{"input": "##TASK\nUpdate subscription response mapping to use common_enums::SubscriptionStatus instead of subscription_types::SubscriptionStatus for consistent enum usage across the codebase.", "output": "##OUTPUT\nThe subscription system needs to standardize on a single SubscriptionStatus enum type to avoid type inconsistencies and import conflicts. The current flow is: subscription handler creates SubscriptionResponse → calls to_subscription_response method → converts string status to subscription_types::SubscriptionStatus → returns response. The to_subscription_response method (crates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response) must change because it's the conversion point where the status enum is instantiated. Using common_enums::SubscriptionStatus ensures consistency with other parts of the system that already use this common enum type. Without this change, the subscription module would maintain its own duplicate enum, creating potential mismatches and requiring conversions throughout the codebase when interacting with other modules that expect the common enum type.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\n<EOS>"}
|
| 3 |
+
{"input": "##TASK\nAdd subscription support to webhook event system by including Subscriptions in EventClass enum and implementing SubscriptionStatus to EventType mapping for proper webhook routing.", "output": "##OUTPUT\nThe webhook system needs to handle subscription-related events for proper event classification and routing. The current flow is: webhook event received → EventClass determines routing category → EventType specifies the exact event type → handlers process based on event type. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must include Subscriptions because it's the primary classification that determines which subsystem handles the event. The SubscriptionStatus transformer (crates/common_enums/src/transformers.rs::impl::Option<EventType>) must map subscription status changes to appropriate EventType values because it converts domain state changes into webhook events. Without EventClass::Subscriptions, subscription events would be misclassified or unhandled. Without the SubscriptionStatus mapping, subscription status changes wouldn't trigger appropriate webhook notifications, breaking integration with external systems that depend on subscription lifecycle events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>"}
|
| 4 |
+
{"input": "##TASK\nAdd subscription ID support to webhook object reference system for proper webhook payload identification and routing.", "output": "##OUTPUT\nThe webhook system needs to reference subscription entities in event payloads for proper event correlation and processing. The current flow is: webhook event generated → ObjectReferenceId identifies the affected entity → webhook payload includes reference → consumers process based on entity type. The ObjectReferenceId enum (crates/api_models/src/webhooks.rs::ObjectReferenceId) must include SubscriptionId because it's the type-safe identifier used throughout the webhook payload structure to specify which subscription triggered the event. Without SubscriptionId, webhook events related to subscriptions couldn't properly reference the subscription entity, making it impossible for consumers to correlate events with specific subscriptions. This would break webhook consumers that need to update their local state or trigger business logic based on subscription events.\n\n##SELECT\ncrates/api_models/src/webhooks.rs::ObjectReferenceId\n<EOS>"}
|
trainer-kit/SFT-14b/requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
transformers>=4.41.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
accelerate>=0.30.0
|
| 6 |
+
|
| 7 |
+
# PEFT / QLoRA
|
| 8 |
+
peft>=0.11.1
|
| 9 |
+
bitsandbytes>=0.43.1
|
| 10 |
+
|
| 11 |
+
# Hugging Face Hub (local + download support)
|
| 12 |
+
huggingface_hub>=0.23.0
|
| 13 |
+
|
| 14 |
+
# Config + utilities
|
| 15 |
+
pyyaml>=6.0
|
| 16 |
+
tqdm>=4.66.0
|
| 17 |
+
|
| 18 |
+
# Optional but recommended (tokenizers speed)
|
| 19 |
+
tokenizers>=0.15.0
|
| 20 |
+
safetensors>=0.4.2
|
| 21 |
+
|
| 22 |
+
# Experiment tracking
|
| 23 |
+
wandb>=0.16.0
|
trainer-kit/SFT-14b/run_instruct.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect # Added for Transformers version compatibility
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
from datasets import load_dataset, DatasetDict
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from transformers import (
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
BitsAndBytesConfig,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
Trainer,
|
| 19 |
+
TrainerCallback,
|
| 20 |
+
EarlyStoppingCallback,
|
| 21 |
+
default_data_collator,
|
| 22 |
+
set_seed,
|
| 23 |
+
)
|
| 24 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 25 |
+
from peft import (
|
| 26 |
+
LoraConfig,
|
| 27 |
+
get_peft_model,
|
| 28 |
+
prepare_model_for_kbit_training,
|
| 29 |
+
PeftModel,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import wandb
|
| 34 |
+
WANDB_AVAILABLE = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
WANDB_AVAILABLE = False
|
| 37 |
+
wandb = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# --------------------------
|
| 41 |
+
# Helpers
|
| 42 |
+
# --------------------------
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 46 |
+
s = (s or "").lower()
|
| 47 |
+
if s in ("float16", "fp16"):
|
| 48 |
+
return torch.float16
|
| 49 |
+
if s in ("bfloat16", "bf16"):
|
| 50 |
+
return torch.bfloat16
|
| 51 |
+
if s in ("float32", "fp32"):
|
| 52 |
+
return torch.float32
|
| 53 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _now_iso() -> str:
|
| 57 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _safe_exp(x: float) -> float:
|
| 61 |
+
x = min(float(x), 50.0)
|
| 62 |
+
return float(math.exp(x))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _ensure_dir(p: Path) -> Path:
|
| 66 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 67 |
+
return p
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 71 |
+
if not p.exists() or not p.is_dir():
|
| 72 |
+
return False
|
| 73 |
+
if (p / "config.json").exists():
|
| 74 |
+
return True
|
| 75 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 76 |
+
return True
|
| 77 |
+
return False
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _infer_target_modules(model) -> List[str]:
|
| 81 |
+
names = set()
|
| 82 |
+
for n, _ in model.named_modules():
|
| 83 |
+
names.add(n.split(".")[-1])
|
| 84 |
+
|
| 85 |
+
for group in [
|
| 86 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 87 |
+
["Wqkv", "out_proj"],
|
| 88 |
+
["query_key_value", "dense"],
|
| 89 |
+
["c_attn", "c_proj"],
|
| 90 |
+
]:
|
| 91 |
+
if all(x in names for x in group):
|
| 92 |
+
return group
|
| 93 |
+
|
| 94 |
+
fallback = [
|
| 95 |
+
x
|
| 96 |
+
for x in [
|
| 97 |
+
"q_proj",
|
| 98 |
+
"k_proj",
|
| 99 |
+
"v_proj",
|
| 100 |
+
"o_proj",
|
| 101 |
+
"c_attn",
|
| 102 |
+
"c_proj",
|
| 103 |
+
"out_proj",
|
| 104 |
+
"dense",
|
| 105 |
+
]
|
| 106 |
+
if x in names
|
| 107 |
+
]
|
| 108 |
+
if fallback:
|
| 109 |
+
return fallback
|
| 110 |
+
|
| 111 |
+
raise ValueError(
|
| 112 |
+
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 117 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# --------------------------
|
| 121 |
+
# Wandb Integration
|
| 122 |
+
# --------------------------
|
| 123 |
+
|
| 124 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 125 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 126 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 127 |
+
|
| 128 |
+
if not wandb_cfg.get("enabled", False):
|
| 129 |
+
print("Wandb logging disabled")
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
if not WANDB_AVAILABLE:
|
| 133 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
# Extract wandb configuration
|
| 137 |
+
project = wandb_cfg.get("project", "sft-training")
|
| 138 |
+
entity = wandb_cfg.get("entity", None)
|
| 139 |
+
name = wandb_cfg.get("name", None)
|
| 140 |
+
tags = wandb_cfg.get("tags", [])
|
| 141 |
+
notes = wandb_cfg.get("notes", None)
|
| 142 |
+
|
| 143 |
+
# Initialize wandb
|
| 144 |
+
try:
|
| 145 |
+
wandb.init(
|
| 146 |
+
project=project,
|
| 147 |
+
entity=entity,
|
| 148 |
+
name=name,
|
| 149 |
+
tags=tags,
|
| 150 |
+
notes=notes,
|
| 151 |
+
dir=str(run_dir),
|
| 152 |
+
config={
|
| 153 |
+
"model": cfg.get("model", {}),
|
| 154 |
+
"data": cfg.get("data", {}),
|
| 155 |
+
"peft": cfg.get("peft", {}),
|
| 156 |
+
"train": cfg.get("train", {}),
|
| 157 |
+
"run_dir": str(run_dir),
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 161 |
+
return wandb
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 164 |
+
return None
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def finish_wandb():
|
| 168 |
+
"""Finish Wandb run if active."""
|
| 169 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 170 |
+
wandb.finish()
|
| 171 |
+
print("Wandb run finished")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# --------------------------
|
| 175 |
+
# JSONL Logger Callback
|
| 176 |
+
# --------------------------
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 180 |
+
def __init__(self, run_dir: Path):
|
| 181 |
+
self.run_dir = run_dir
|
| 182 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 183 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 184 |
+
self.start_time = None
|
| 185 |
+
|
| 186 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 187 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 188 |
+
return None
|
| 189 |
+
elapsed = time.time() - self.start_time
|
| 190 |
+
sec_per_step = elapsed / global_step
|
| 191 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 192 |
+
h = int(remaining // 3600)
|
| 193 |
+
m = int((remaining % 3600) // 60)
|
| 194 |
+
s = int(remaining % 60)
|
| 195 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 196 |
+
|
| 197 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 198 |
+
self.start_time = time.time()
|
| 199 |
+
|
| 200 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 201 |
+
if not logs:
|
| 202 |
+
return
|
| 203 |
+
|
| 204 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 205 |
+
progress_pct = (
|
| 206 |
+
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 207 |
+
)
|
| 208 |
+
epoch_pct = None
|
| 209 |
+
if (
|
| 210 |
+
state.epoch is not None
|
| 211 |
+
and args.num_train_epochs
|
| 212 |
+
and args.num_train_epochs > 0
|
| 213 |
+
):
|
| 214 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 215 |
+
|
| 216 |
+
payload = {
|
| 217 |
+
"ts": _now_iso(),
|
| 218 |
+
"event": "train_log",
|
| 219 |
+
"step": int(state.global_step),
|
| 220 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 221 |
+
"progress_pct": (
|
| 222 |
+
round(progress_pct, 2) if progress_pct is not None else None
|
| 223 |
+
),
|
| 224 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 225 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 226 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 227 |
+
**logs,
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 231 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 232 |
+
|
| 233 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 234 |
+
if not metrics:
|
| 235 |
+
return
|
| 236 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 237 |
+
ppl = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 238 |
+
|
| 239 |
+
payload = {
|
| 240 |
+
"ts": _now_iso(),
|
| 241 |
+
"event": "eval",
|
| 242 |
+
"step": int(state.global_step),
|
| 243 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 244 |
+
**metrics,
|
| 245 |
+
"perplexity": ppl,
|
| 246 |
+
}
|
| 247 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 248 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# --------------------------
|
| 252 |
+
# Data Pipeline (Instruction Formatting)
|
| 253 |
+
# --------------------------
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def format_instruction(
|
| 257 |
+
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
|
| 258 |
+
) -> Dict[str, Any]:
|
| 259 |
+
"""
|
| 260 |
+
Format instruction data for training.
|
| 261 |
+
Supports multiple formats: chatml, alpaca, custom templates.
|
| 262 |
+
Returns both formatted text and the response start position for loss masking.
|
| 263 |
+
"""
|
| 264 |
+
data_cfg = cfg["data"]
|
| 265 |
+
format_type = data_cfg.get("format_type", "chatml")
|
| 266 |
+
|
| 267 |
+
# Get field names from config
|
| 268 |
+
input_field = data_cfg.get("input_field", "input")
|
| 269 |
+
output_field = data_cfg.get("output_field", "output")
|
| 270 |
+
instruction_field = data_cfg.get("instruction_field", "instruction")
|
| 271 |
+
|
| 272 |
+
# Extract text from example
|
| 273 |
+
instruction = example.get(instruction_field, "")
|
| 274 |
+
input_text = example.get(input_field, "")
|
| 275 |
+
output_text = example.get(output_field, "")
|
| 276 |
+
|
| 277 |
+
if format_type == "chatml":
|
| 278 |
+
# ChatML format with special tokens
|
| 279 |
+
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
|
| 280 |
+
|
| 281 |
+
messages = []
|
| 282 |
+
if system_prompt:
|
| 283 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 284 |
+
|
| 285 |
+
user_content = instruction
|
| 286 |
+
if input_text:
|
| 287 |
+
user_content = f"{instruction}\n\n{input_text}"
|
| 288 |
+
messages.append({"role": "user", "content": user_content})
|
| 289 |
+
messages.append({"role": "assistant", "content": output_text})
|
| 290 |
+
|
| 291 |
+
# Apply chat template
|
| 292 |
+
formatted_text = tokenizer.apply_chat_template(
|
| 293 |
+
messages, tokenize=False, add_generation_prompt=False
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Add EOS token if not present
|
| 297 |
+
if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
|
| 298 |
+
formatted_text += tokenizer.eos_token
|
| 299 |
+
|
| 300 |
+
# Find where the assistant response starts for loss masking
|
| 301 |
+
# Try multiple possible markers for robustness
|
| 302 |
+
markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
|
| 303 |
+
response_start_pos = -1
|
| 304 |
+
|
| 305 |
+
for marker in markers:
|
| 306 |
+
idx = formatted_text.find(marker)
|
| 307 |
+
if idx != -1:
|
| 308 |
+
# Find the newline after the marker
|
| 309 |
+
newline_idx = formatted_text.find("\n", idx)
|
| 310 |
+
if newline_idx != -1:
|
| 311 |
+
response_start_pos = newline_idx + 1
|
| 312 |
+
break
|
| 313 |
+
|
| 314 |
+
# Fallback: find where the actual output starts
|
| 315 |
+
if response_start_pos == -1:
|
| 316 |
+
output_idx = formatted_text.find(output_text)
|
| 317 |
+
if output_idx != -1:
|
| 318 |
+
response_start_pos = output_idx
|
| 319 |
+
else:
|
| 320 |
+
# Last resort: split at last occurrence of newline before end
|
| 321 |
+
response_start_pos = formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1
|
| 322 |
+
|
| 323 |
+
elif format_type == "alpaca":
|
| 324 |
+
# Alpaca format
|
| 325 |
+
if input_text:
|
| 326 |
+
prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
|
| 327 |
+
else:
|
| 328 |
+
prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
|
| 329 |
+
|
| 330 |
+
formatted_text = prefix + output_text
|
| 331 |
+
|
| 332 |
+
# Add EOS token
|
| 333 |
+
if tokenizer.eos_token:
|
| 334 |
+
formatted_text += tokenizer.eos_token
|
| 335 |
+
|
| 336 |
+
# Response starts after the prefix
|
| 337 |
+
response_start_pos = len(prefix)
|
| 338 |
+
|
| 339 |
+
elif format_type == "custom":
|
| 340 |
+
# Custom template from config
|
| 341 |
+
template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
|
| 342 |
+
|
| 343 |
+
# For custom format, use system_prompt as instruction if instruction field is empty
|
| 344 |
+
if not instruction:
|
| 345 |
+
instruction = data_cfg.get("system_prompt", "")
|
| 346 |
+
|
| 347 |
+
# For custom templates, we need to find where {output} starts
|
| 348 |
+
template_parts = template.split("{output}")
|
| 349 |
+
prefix = template_parts[0].format(instruction=instruction, input=input_text)
|
| 350 |
+
formatted_text = prefix + output_text
|
| 351 |
+
|
| 352 |
+
# Add EOS token if not already in template
|
| 353 |
+
if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
|
| 354 |
+
formatted_text += tokenizer.eos_token
|
| 355 |
+
|
| 356 |
+
# Response starts after the prefix
|
| 357 |
+
response_start_pos = len(prefix)
|
| 358 |
+
else:
|
| 359 |
+
raise ValueError(f"Unsupported format_type: {format_type}")
|
| 360 |
+
|
| 361 |
+
return {"text": formatted_text, "response_start_pos": response_start_pos}
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 365 |
+
"""
|
| 366 |
+
Build datasets for instruction fine-tuning.
|
| 367 |
+
"""
|
| 368 |
+
data_cfg = cfg["data"]
|
| 369 |
+
train_path = data_cfg["train_jsonl"]
|
| 370 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 371 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 372 |
+
max_length = int(data_cfg.get("max_length", 2048))
|
| 373 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 374 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 375 |
+
|
| 376 |
+
# Ensure tokenizer has pad token
|
| 377 |
+
if tokenizer.pad_token is None:
|
| 378 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 379 |
+
|
| 380 |
+
# Load datasets
|
| 381 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 382 |
+
|
| 383 |
+
if eval_path:
|
| 384 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 385 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 386 |
+
else:
|
| 387 |
+
if 0.0 < split_ratio < 1.0:
|
| 388 |
+
split = ds["train"].train_test_split(
|
| 389 |
+
test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
|
| 390 |
+
)
|
| 391 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 392 |
+
else:
|
| 393 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 394 |
+
|
| 395 |
+
# Format instructions and track response start positions
|
| 396 |
+
def format_fn(examples):
|
| 397 |
+
formatted_examples = []
|
| 398 |
+
response_start_positions = []
|
| 399 |
+
for i in range(len(examples[list(examples.keys())[0]])):
|
| 400 |
+
example = {k: examples[k][i] for k in examples.keys()}
|
| 401 |
+
formatted = format_instruction(example, cfg, tokenizer)
|
| 402 |
+
formatted_examples.append(formatted["text"])
|
| 403 |
+
response_start_positions.append(formatted["response_start_pos"])
|
| 404 |
+
return {
|
| 405 |
+
"text": formatted_examples,
|
| 406 |
+
"response_start_pos": response_start_positions
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
formatted_train = dsd["train"].map(
|
| 410 |
+
format_fn,
|
| 411 |
+
batched=True,
|
| 412 |
+
num_proc=num_proc,
|
| 413 |
+
remove_columns=dsd["train"].column_names,
|
| 414 |
+
desc="Formatting train instructions",
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
formatted_eval = None
|
| 418 |
+
if dsd["eval"] is not None:
|
| 419 |
+
formatted_eval = dsd["eval"].map(
|
| 420 |
+
format_fn,
|
| 421 |
+
batched=True,
|
| 422 |
+
num_proc=num_proc,
|
| 423 |
+
remove_columns=dsd["eval"].column_names,
|
| 424 |
+
desc="Formatting eval instructions",
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# Tokenize and apply loss masking
|
| 428 |
+
def tokenize_and_mask_fn(examples):
|
| 429 |
+
tokenized = tokenizer(
|
| 430 |
+
examples["text"],
|
| 431 |
+
truncation=True,
|
| 432 |
+
padding=False,
|
| 433 |
+
max_length=max_length,
|
| 434 |
+
return_overflowing_tokens=False,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Apply loss masking - CRITICAL for SFT
|
| 438 |
+
labels = []
|
| 439 |
+
attention_masks = []
|
| 440 |
+
|
| 441 |
+
for i in range(len(tokenized["input_ids"])):
|
| 442 |
+
input_ids = tokenized["input_ids"][i]
|
| 443 |
+
response_start_pos = examples["response_start_pos"][i]
|
| 444 |
+
|
| 445 |
+
# Get the instruction part (before response)
|
| 446 |
+
full_text = examples["text"][i]
|
| 447 |
+
instruction_text = full_text[:response_start_pos]
|
| 448 |
+
|
| 449 |
+
# Create labels masked by default
|
| 450 |
+
label_ids = [-100] * len(input_ids)
|
| 451 |
+
|
| 452 |
+
# Find where response starts using character-based ratio
|
| 453 |
+
# This is more reliable than tokenizing prefix separately
|
| 454 |
+
# because separate tokenization can add different special tokens
|
| 455 |
+
char_ratio = response_start_pos / max(len(full_text), 1)
|
| 456 |
+
response_start_idx = int(len(input_ids) * char_ratio)
|
| 457 |
+
|
| 458 |
+
# Ensure we have valid bounds (at least position 1, at most len-1)
|
| 459 |
+
response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
|
| 460 |
+
|
| 461 |
+
# Unmask response tokens (including EOS)
|
| 462 |
+
for j in range(response_start_idx, len(input_ids)):
|
| 463 |
+
label_ids[j] = input_ids[j]
|
| 464 |
+
|
| 465 |
+
# Create attention mask (1 for real tokens, 0 for padding)
|
| 466 |
+
attention_mask = [1] * len(input_ids)
|
| 467 |
+
|
| 468 |
+
labels.append(label_ids)
|
| 469 |
+
attention_masks.append(attention_mask)
|
| 470 |
+
|
| 471 |
+
tokenized["labels"] = labels
|
| 472 |
+
tokenized["attention_mask"] = attention_masks
|
| 473 |
+
return tokenized
|
| 474 |
+
|
| 475 |
+
tokenized_train = formatted_train.map(
|
| 476 |
+
tokenize_and_mask_fn,
|
| 477 |
+
batched=True,
|
| 478 |
+
num_proc=num_proc,
|
| 479 |
+
desc="Tokenizing and masking train",
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
tokenized_eval = None
|
| 483 |
+
if formatted_eval is not None:
|
| 484 |
+
tokenized_eval = formatted_eval.map(
|
| 485 |
+
tokenize_and_mask_fn,
|
| 486 |
+
batched=True,
|
| 487 |
+
num_proc=num_proc,
|
| 488 |
+
desc="Tokenizing and masking eval",
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if shuffle:
|
| 492 |
+
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 493 |
+
|
| 494 |
+
return tokenized_train, tokenized_eval
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
# --------------------------
|
| 498 |
+
# Model Loading + PEFT
|
| 499 |
+
# --------------------------
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 503 |
+
model_cfg = cfg["model"]
|
| 504 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 505 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 506 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 507 |
+
|
| 508 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 509 |
+
str(base_dir),
|
| 510 |
+
use_fast=use_fast,
|
| 511 |
+
trust_remote_code=trust_remote_code,
|
| 512 |
+
)
|
| 513 |
+
if tokenizer.pad_token is None:
|
| 514 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 515 |
+
|
| 516 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 517 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 518 |
+
|
| 519 |
+
quant_cfg = None
|
| 520 |
+
if use_4bit:
|
| 521 |
+
quant_cfg = BitsAndBytesConfig(
|
| 522 |
+
load_in_4bit=True,
|
| 523 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 524 |
+
bnb_4bit_use_double_quant=bool(
|
| 525 |
+
model_cfg.get("bnb_4bit_use_double_quant", True)
|
| 526 |
+
),
|
| 527 |
+
bnb_4bit_compute_dtype=_dtype_from_str(
|
| 528 |
+
model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
|
| 529 |
+
),
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 533 |
+
|
| 534 |
+
try:
|
| 535 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 536 |
+
str(base_dir),
|
| 537 |
+
device_map=device_map,
|
| 538 |
+
trust_remote_code=trust_remote_code,
|
| 539 |
+
low_cpu_mem_usage=True,
|
| 540 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 541 |
+
quantization_config=quant_cfg,
|
| 542 |
+
attn_implementation=attn_impl,
|
| 543 |
+
)
|
| 544 |
+
except Exception as e:
|
| 545 |
+
if attn_impl is not None:
|
| 546 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 547 |
+
print("[warn] Falling back to default attention implementation.")
|
| 548 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 549 |
+
str(base_dir),
|
| 550 |
+
device_map=device_map,
|
| 551 |
+
trust_remote_code=trust_remote_code,
|
| 552 |
+
low_cpu_mem_usage=True,
|
| 553 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 554 |
+
quantization_config=quant_cfg,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
return model, tokenizer
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 561 |
+
peft_cfg = cfg["peft"]
|
| 562 |
+
model_cfg = cfg["model"]
|
| 563 |
+
tr_cfg = cfg["train"]
|
| 564 |
+
|
| 565 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 566 |
+
return model, None
|
| 567 |
+
|
| 568 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 569 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 570 |
+
|
| 571 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 572 |
+
model.gradient_checkpointing_enable()
|
| 573 |
+
if hasattr(model, "config"):
|
| 574 |
+
model.config.use_cache = False
|
| 575 |
+
|
| 576 |
+
if use_4bit:
|
| 577 |
+
model = prepare_model_for_kbit_training(
|
| 578 |
+
model,
|
| 579 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 583 |
+
if target_modules == "auto":
|
| 584 |
+
target_modules = _infer_target_modules(model)
|
| 585 |
+
|
| 586 |
+
lora_config = LoraConfig(
|
| 587 |
+
r=int(peft_cfg.get("r", 16)),
|
| 588 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 589 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 590 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 591 |
+
task_type="CAUSAL_LM",
|
| 592 |
+
target_modules=target_modules,
|
| 593 |
+
)
|
| 594 |
+
model = get_peft_model(model, lora_config)
|
| 595 |
+
return model, lora_config
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
# --------------------------
|
| 599 |
+
# Merge Logic
|
| 600 |
+
# --------------------------
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def merge_adapter(
|
| 604 |
+
cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
|
| 605 |
+
):
|
| 606 |
+
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 607 |
+
|
| 608 |
+
model_cfg = cfg["model"]
|
| 609 |
+
merge_cfg = cfg.get("merge", {})
|
| 610 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 611 |
+
|
| 612 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 613 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 614 |
+
|
| 615 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 616 |
+
str(base_dir),
|
| 617 |
+
torch_dtype=merged_dtype,
|
| 618 |
+
device_map="cpu",
|
| 619 |
+
low_cpu_mem_usage=True,
|
| 620 |
+
trust_remote_code=trust_remote_code,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 624 |
+
merged = merged.merge_and_unload()
|
| 625 |
+
|
| 626 |
+
_ensure_dir(final_dir)
|
| 627 |
+
merged.save_pretrained(
|
| 628 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
tok = AutoTokenizer.from_pretrained(
|
| 632 |
+
str(base_dir), trust_remote_code=trust_remote_code
|
| 633 |
+
)
|
| 634 |
+
if tok.pad_token is None:
|
| 635 |
+
tok.pad_token = tok.eos_token
|
| 636 |
+
tok.save_pretrained(str(final_dir))
|
| 637 |
+
|
| 638 |
+
print("--- Merge complete ---")
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
# --------------------------
|
| 642 |
+
# Main
|
| 643 |
+
# --------------------------
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def main():
|
| 647 |
+
ap = argparse.ArgumentParser()
|
| 648 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 649 |
+
ap.add_argument(
|
| 650 |
+
"--merge-only", action="store_true", help="Skip training, just merge adapter"
|
| 651 |
+
)
|
| 652 |
+
args = ap.parse_args()
|
| 653 |
+
|
| 654 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 655 |
+
cfg = yaml.safe_load(f)
|
| 656 |
+
|
| 657 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 658 |
+
_ensure_dir(run_dir / "logs")
|
| 659 |
+
|
| 660 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 661 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 662 |
+
|
| 663 |
+
model_cfg = cfg["model"]
|
| 664 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 665 |
+
repo_path = Path(repo_id)
|
| 666 |
+
|
| 667 |
+
# ✅ Local model path -> load directly; no download
|
| 668 |
+
if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
|
| 669 |
+
base_dir = repo_path
|
| 670 |
+
print(f"Using local model at: {base_dir}")
|
| 671 |
+
elif repo_path.exists() and repo_path.is_dir():
|
| 672 |
+
raise ValueError(
|
| 673 |
+
f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}"
|
| 674 |
+
)
|
| 675 |
+
else:
|
| 676 |
+
# HF repo_id -> download into run_dir/base_local_dir
|
| 677 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 678 |
+
if not _looks_like_model_dir(base_dir):
|
| 679 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 680 |
+
snapshot_download(
|
| 681 |
+
repo_id=repo_id,
|
| 682 |
+
revision=model_cfg.get("revision", None),
|
| 683 |
+
local_dir=str(base_dir),
|
| 684 |
+
local_dir_use_symlinks=False,
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 688 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 689 |
+
|
| 690 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 691 |
+
if merge_cfg.get("output_dir"):
|
| 692 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 693 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 694 |
+
else:
|
| 695 |
+
final_dir = run_dir / "final_model"
|
| 696 |
+
|
| 697 |
+
# Merge-only
|
| 698 |
+
if args.merge_only:
|
| 699 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 700 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 701 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 702 |
+
return
|
| 703 |
+
|
| 704 |
+
# Initialize Wandb
|
| 705 |
+
wandb_run = setup_wandb(cfg, run_dir)
|
| 706 |
+
|
| 707 |
+
# Training
|
| 708 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 709 |
+
|
| 710 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 711 |
+
model, _ = apply_peft(cfg, model)
|
| 712 |
+
|
| 713 |
+
train_ds, eval_ds = build_datasets(cfg, tokenizer)
|
| 714 |
+
|
| 715 |
+
tr_cfg = cfg["train"]
|
| 716 |
+
|
| 717 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 718 |
+
use_fp16 = dtype == torch.float16
|
| 719 |
+
use_bf16 = dtype == torch.bfloat16
|
| 720 |
+
|
| 721 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 722 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 723 |
+
|
| 724 |
+
# --- Dynamic evaluation strategy parameter handling ---
|
| 725 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 726 |
+
eval_key = (
|
| 727 |
+
"eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
# Setup reporting based on wandb availability
|
| 731 |
+
report_to = []
|
| 732 |
+
if wandb_run is not None:
|
| 733 |
+
report_to.append("wandb")
|
| 734 |
+
|
| 735 |
+
ta_kwargs = dict(
|
| 736 |
+
output_dir=str(ckpt_dir),
|
| 737 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 738 |
+
num_train_epochs=num_train_epochs,
|
| 739 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 740 |
+
per_device_eval_batch_size=int(
|
| 741 |
+
tr_cfg.get(
|
| 742 |
+
"per_device_eval_batch_size",
|
| 743 |
+
tr_cfg.get("per_device_train_batch_size", 1),
|
| 744 |
+
)
|
| 745 |
+
),
|
| 746 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 747 |
+
learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
|
| 748 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 749 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 750 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 751 |
+
optim=str(
|
| 752 |
+
tr_cfg.get(
|
| 753 |
+
"optim",
|
| 754 |
+
(
|
| 755 |
+
"paged_adamw_8bit"
|
| 756 |
+
if bool(model_cfg.get("use_4bit", False))
|
| 757 |
+
else "adamw_torch"
|
| 758 |
+
),
|
| 759 |
+
)
|
| 760 |
+
),
|
| 761 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 762 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 763 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 764 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 765 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 766 |
+
eval_steps=int(tr_cfg.get("eval_steps", 200)),
|
| 767 |
+
load_best_model_at_end=(
|
| 768 |
+
bool(tr_cfg.get("load_best_model_at_end", True))
|
| 769 |
+
if eval_ds is not None
|
| 770 |
+
else False
|
| 771 |
+
),
|
| 772 |
+
metric_for_best_model="eval_loss",
|
| 773 |
+
greater_is_better=False,
|
| 774 |
+
fp16=use_fp16,
|
| 775 |
+
bf16=use_bf16,
|
| 776 |
+
report_to=report_to,
|
| 777 |
+
remove_unused_columns=False,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
# Set the correct argument name for this transformers version
|
| 781 |
+
ta_kwargs[eval_key] = str(
|
| 782 |
+
tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 786 |
+
|
| 787 |
+
# Setup callbacks
|
| 788 |
+
callbacks = [JsonlLoggerCallback(run_dir)]
|
| 789 |
+
|
| 790 |
+
# Add early stopping callback if enabled
|
| 791 |
+
early_stopping_cfg = tr_cfg.get("early_stopping", {})
|
| 792 |
+
if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
|
| 793 |
+
early_stopping_callback = EarlyStoppingCallback(
|
| 794 |
+
early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
|
| 795 |
+
early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
|
| 796 |
+
)
|
| 797 |
+
callbacks.append(early_stopping_callback)
|
| 798 |
+
print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
|
| 799 |
+
f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
|
| 800 |
+
|
| 801 |
+
trainer = Trainer(
|
| 802 |
+
model=model,
|
| 803 |
+
args=training_args,
|
| 804 |
+
train_dataset=train_ds,
|
| 805 |
+
eval_dataset=eval_ds,
|
| 806 |
+
data_collator=default_data_collator,
|
| 807 |
+
callbacks=callbacks,
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
# Resume
|
| 811 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 812 |
+
if resume_from == "auto":
|
| 813 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 814 |
+
resume_from = last if last else None
|
| 815 |
+
if resume_from:
|
| 816 |
+
print(f"Resuming from {resume_from}")
|
| 817 |
+
|
| 818 |
+
print("Starting instruction fine-tuning...")
|
| 819 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 820 |
+
|
| 821 |
+
trainer.save_model(str(best_adapter_dir))
|
| 822 |
+
print(f"Saved best adapter -> {best_adapter_dir}")
|
| 823 |
+
|
| 824 |
+
if eval_ds is not None:
|
| 825 |
+
metrics = trainer.evaluate()
|
| 826 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 827 |
+
metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 828 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 829 |
+
json.dump(metrics, f, indent=2)
|
| 830 |
+
print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
|
| 831 |
+
|
| 832 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 833 |
+
del trainer, model
|
| 834 |
+
torch.cuda.empty_cache()
|
| 835 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 836 |
+
else:
|
| 837 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 838 |
+
|
| 839 |
+
# Finish Wandb run
|
| 840 |
+
finish_wandb()
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
if __name__ == "__main__":
|
| 844 |
+
main()
|
trainer-kit/SFT/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
trainer-kit/SFT/config_instruct.yaml
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
run_dir: "./runs/instruct_run_24b"
|
| 3 |
+
seed: 42
|
| 4 |
+
|
| 5 |
+
# WandB integration for experiment tracking
|
| 6 |
+
wandb:
|
| 7 |
+
enabled: true # Set to true to enable wandb logging
|
| 8 |
+
project: "sft-training" # WandB project name
|
| 9 |
+
entity: null # WandB entity/team (optional)
|
| 10 |
+
name: null # Run name (optional, will auto-generate if null)
|
| 11 |
+
tags: ["sft-lora", "24b-Devstral"] # List of tags for the run (e.g., ["lora", "qlora", "experiment-1"])
|
| 12 |
+
notes: null # Run description/notes (optional)
|
| 13 |
+
|
| 14 |
+
model:
|
| 15 |
+
# Use local Qwen2.5-Coder-14B model
|
| 16 |
+
repo_id: "./CPT/runs/cpt_run_v1/merged_24b_cpt_lora"
|
| 17 |
+
revision: null
|
| 18 |
+
|
| 19 |
+
# Used only when repo_id is a HF repo (not a local path)
|
| 20 |
+
base_local_dir: "base_model"
|
| 21 |
+
|
| 22 |
+
trust_remote_code: true
|
| 23 |
+
tokenizer_use_fast: true
|
| 24 |
+
device_map: "auto"
|
| 25 |
+
|
| 26 |
+
torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
|
| 27 |
+
|
| 28 |
+
# QLoRA
|
| 29 |
+
use_4bit: false
|
| 30 |
+
bnb_4bit_quant_type: "nf4"
|
| 31 |
+
bnb_4bit_use_double_quant: false
|
| 32 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 33 |
+
|
| 34 |
+
# optional: "flash_attention_2" | "sdpa" | null
|
| 35 |
+
attn_implementation: null
|
| 36 |
+
|
| 37 |
+
data:
|
| 38 |
+
train_jsonl: "../sft_dataset.jsonl"
|
| 39 |
+
eval_jsonl: null
|
| 40 |
+
eval_split_ratio: 0.1
|
| 41 |
+
|
| 42 |
+
# Field names in your JSONL data
|
| 43 |
+
instruction_field: "instruction" # This will be the system prompt
|
| 44 |
+
input_field: "input" # This is the task description
|
| 45 |
+
output_field: "output" # This is the analysis + selection
|
| 46 |
+
|
| 47 |
+
# Formatting options
|
| 48 |
+
format_type: "custom" # "chatml" | "alpaca" | "custom"
|
| 49 |
+
|
| 50 |
+
# For chatml format
|
| 51 |
+
system_prompt: |
|
| 52 |
+
You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
|
| 53 |
+
|
| 54 |
+
## Output Format
|
| 55 |
+
|
| 56 |
+
##OUTPUT
|
| 57 |
+
Explain the data flow and why each component must change:
|
| 58 |
+
- Flow: [Input → Processing → Output with arrows]
|
| 59 |
+
- For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
|
| 60 |
+
- Explain coupling between components
|
| 61 |
+
|
| 62 |
+
##SELECT
|
| 63 |
+
modify::crates/path/to/file.rs::impl::ComponentName
|
| 64 |
+
add::crates/another/file.rs::function::AnotherComponent
|
| 65 |
+
<EOS>
|
| 66 |
+
|
| 67 |
+
## Rules
|
| 68 |
+
|
| 69 |
+
1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
|
| 70 |
+
2. Use `::` for nested items: `status::StructName::Type::Name`
|
| 71 |
+
3. Always explain "must change because" and "without this"
|
| 72 |
+
3. Types of components: function, struct, enum, impl, trait
|
| 73 |
+
4. If there is extra information (e.g., enum variants), include that too.
|
| 74 |
+
5. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
|
| 75 |
+
|
| 76 |
+
## Example
|
| 77 |
+
|
| 78 |
+
##TASK
|
| 79 |
+
Add webhook subscription support
|
| 80 |
+
|
| 81 |
+
##OUTPUT
|
| 82 |
+
The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
|
| 83 |
+
|
| 84 |
+
##SELECT
|
| 85 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 86 |
+
crates/common_enums/src/transformers.rs::SubscriptionStatus
|
| 87 |
+
<EOS>
|
| 88 |
+
|
| 89 |
+
# For custom format (only used when format_type="custom")
|
| 90 |
+
custom_template: "##INSTRUCTION\n{instruction}<|im_end|>\n##TASK\n{input}<|im_end|>\n##OUTPUT\n{output}<|im_end|>"
|
| 91 |
+
|
| 92 |
+
max_length: 2048
|
| 93 |
+
shuffle: true
|
| 94 |
+
num_proc: 4
|
| 95 |
+
|
| 96 |
+
peft:
|
| 97 |
+
enabled: true
|
| 98 |
+
r: 8
|
| 99 |
+
lora_alpha: 16
|
| 100 |
+
lora_dropout: 0.05
|
| 101 |
+
bias: "none"
|
| 102 |
+
target_modules: "auto"
|
| 103 |
+
|
| 104 |
+
train:
|
| 105 |
+
# max_steps: 10
|
| 106 |
+
num_train_epochs: 6
|
| 107 |
+
|
| 108 |
+
per_device_train_batch_size: 1
|
| 109 |
+
per_device_eval_batch_size: 1
|
| 110 |
+
gradient_accumulation_steps: 8
|
| 111 |
+
|
| 112 |
+
learning_rate: 1e-4
|
| 113 |
+
weight_decay: 0.0
|
| 114 |
+
warmup_ratio: 0.08
|
| 115 |
+
lr_scheduler_type: "cosine"
|
| 116 |
+
|
| 117 |
+
optim: "adamw_torch" # ✅ Changed from paged_adamw_8bit (requires use_4bit=true)
|
| 118 |
+
max_grad_norm: 0.8
|
| 119 |
+
gradient_checkpointing: true
|
| 120 |
+
|
| 121 |
+
logging_steps: 2
|
| 122 |
+
save_strategy: "steps"
|
| 123 |
+
save_steps: 500
|
| 124 |
+
save_total_limit: 20
|
| 125 |
+
|
| 126 |
+
evaluation_strategy: "steps"
|
| 127 |
+
eval_steps: 100
|
| 128 |
+
load_best_model_at_end: true
|
| 129 |
+
|
| 130 |
+
# Early stopping
|
| 131 |
+
early_stopping:
|
| 132 |
+
enabled: true
|
| 133 |
+
patience: 3 # Number of evaluations with no improvement before stopping
|
| 134 |
+
min_delta: 0.001 # Minimum change to qualify as improvement
|
| 135 |
+
metric: "eval_loss" # Metric to monitor
|
| 136 |
+
mode: "min" # "min" for loss, "max" for accuracy/etc.
|
| 137 |
+
|
| 138 |
+
resume_from_checkpoint: "auto"
|
| 139 |
+
|
| 140 |
+
merge:
|
| 141 |
+
enabled: true
|
| 142 |
+
merged_dtype: "float16"
|
| 143 |
+
max_shard_size: "2GB"
|
| 144 |
+
output_dir: "./merged_24b_instruct_lora"
|
trainer-kit/SFT/instruct_data.jsonl
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"input": "##TASK\nAdd subscription support to webhook event metadata and outgoing webhook content systems to enable proper subscription lifecycle event handling and webhook delivery for subscription-related operations.", "output": "##OUTPUT\nThe webhook system requires subscription event support to handle subscription lifecycle events (creation, payment, cancellation) and deliver them to merchants via outgoing webhooks. This affects the event metadata enumeration and webhook content transformation layers. The complete flow is: subscription event occurs → EventMetadata::Subscription variant created with subscription_id, invoice_id, payment_id → stored in database → retrieved for webhook delivery → OutgoingWebhookEventContent::foreign_from converts EventMetadata to OutgoingWebhookEventContent::Subscription → StripeWebhookObject::From transforms to StripeWebhookObject::Subscriptions for Stripe compatibility. The OutgoingWebhookEventContent::foreign_from function (router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from) must add the Subscription variant match arm because it's the bridge from storage EventMetadata to webhook content representation—without this, subscription events cannot be converted to webhook content and would be silently dropped during webhook processing. The StripeWebhookObject::From implementation (router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From) must handle OutgoingWebhookContent::SubscriptionDetails because it provides Stripe-compatible webhook formatting—without this, merchants using Stripe webhook endpoints would receive malformed or missing subscription events. The EventMetadata enum (implicitly referenced in foreign_from) requires the Subscription variant to exist in the diesel_models layer, which enables persistent storage of subscription event references. These components form a tightly coupled webhook pipeline where each depends on the previous: EventMetadata → OutgoingWebhookEventContent → StripeWebhookObject. If foreign_from lacks the Subscription match arm, subscription events stored as EventMetadata::Subscription would cause panic during webhook delivery. If StripeWebhookObject::From lacks SubscriptionDetails handling, Stripe webhook endpoints would receive unparseable responses, breaking merchant integrations that rely on Stripe-compatible subscription webhooks for revenue recognition and customer management systems.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From\n<EOS>"}
|
| 2 |
+
{"input": "##TASK\nUpdate subscription response mapping to use common_enums::SubscriptionStatus instead of subscription_types::SubscriptionStatus for consistent enum usage across the codebase.", "output": "##OUTPUT\nThe subscription system needs to standardize on a single SubscriptionStatus enum type to avoid type inconsistencies and import conflicts. The current flow is: subscription handler creates SubscriptionResponse → calls to_subscription_response method → converts string status to subscription_types::SubscriptionStatus → returns response. The to_subscription_response method (crates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response) must change because it's the conversion point where the status enum is instantiated. Using common_enums::SubscriptionStatus ensures consistency with other parts of the system that already use this common enum type. Without this change, the subscription module would maintain its own duplicate enum, creating potential mismatches and requiring conversions throughout the codebase when interacting with other modules that expect the common enum type.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\n<EOS>"}
|
| 3 |
+
{"input": "##TASK\nAdd subscription support to webhook event system by including Subscriptions in EventClass enum and implementing SubscriptionStatus to EventType mapping for proper webhook routing.", "output": "##OUTPUT\nThe webhook system needs to handle subscription-related events for proper event classification and routing. The current flow is: webhook event received → EventClass determines routing category → EventType specifies the exact event type → handlers process based on event type. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must include Subscriptions because it's the primary classification that determines which subsystem handles the event. The SubscriptionStatus transformer (crates/common_enums/src/transformers.rs::impl::Option<EventType>) must map subscription status changes to appropriate EventType values because it converts domain state changes into webhook events. Without EventClass::Subscriptions, subscription events would be misclassified or unhandled. Without the SubscriptionStatus mapping, subscription status changes wouldn't trigger appropriate webhook notifications, breaking integration with external systems that depend on subscription lifecycle events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>"}
|
| 4 |
+
{"input": "##TASK\nAdd subscription ID support to webhook object reference system for proper webhook payload identification and routing.", "output": "##OUTPUT\nThe webhook system needs to reference subscription entities in event payloads for proper event correlation and processing. The current flow is: webhook event generated → ObjectReferenceId identifies the affected entity → webhook payload includes reference → consumers process based on entity type. The ObjectReferenceId enum (crates/api_models/src/webhooks.rs::ObjectReferenceId) must include SubscriptionId because it's the type-safe identifier used throughout the webhook payload structure to specify which subscription triggered the event. Without SubscriptionId, webhook events related to subscriptions couldn't properly reference the subscription entity, making it impossible for consumers to correlate events with specific subscriptions. This would break webhook consumers that need to update their local state or trigger business logic based on subscription events.\n\n##SELECT\ncrates/api_models/src/webhooks.rs::ObjectReferenceId\n<EOS>"}
|
trainer-kit/SFT/requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
transformers>=4.41.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
accelerate>=0.30.0
|
| 6 |
+
|
| 7 |
+
# PEFT / QLoRA
|
| 8 |
+
peft>=0.11.1
|
| 9 |
+
bitsandbytes>=0.43.1
|
| 10 |
+
|
| 11 |
+
# Hugging Face Hub (local + download support)
|
| 12 |
+
huggingface_hub>=0.23.0
|
| 13 |
+
|
| 14 |
+
# Config + utilities
|
| 15 |
+
pyyaml>=6.0
|
| 16 |
+
tqdm>=4.66.0
|
| 17 |
+
|
| 18 |
+
# Optional but recommended (tokenizers speed)
|
| 19 |
+
tokenizers>=0.15.0
|
| 20 |
+
safetensors>=0.4.2
|
| 21 |
+
|
| 22 |
+
# Experiment tracking
|
| 23 |
+
wandb>=0.16.0
|
trainer-kit/SFT/run_instruct.py
ADDED
|
@@ -0,0 +1,921 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect # Added for Transformers version compatibility
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
from datasets import load_dataset, DatasetDict
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from transformers import (
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
AutoModel,
|
| 17 |
+
AutoConfig,
|
| 18 |
+
BitsAndBytesConfig,
|
| 19 |
+
TrainingArguments,
|
| 20 |
+
Trainer,
|
| 21 |
+
TrainerCallback,
|
| 22 |
+
EarlyStoppingCallback,
|
| 23 |
+
default_data_collator,
|
| 24 |
+
set_seed,
|
| 25 |
+
)
|
| 26 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 27 |
+
from peft import (
|
| 28 |
+
LoraConfig,
|
| 29 |
+
get_peft_model,
|
| 30 |
+
prepare_model_for_kbit_training,
|
| 31 |
+
PeftModel,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import wandb
|
| 36 |
+
WANDB_AVAILABLE = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
WANDB_AVAILABLE = False
|
| 39 |
+
wandb = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# --------------------------
|
| 43 |
+
# Helpers
|
| 44 |
+
# --------------------------
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 48 |
+
s = (s or "").lower()
|
| 49 |
+
if s in ("float16", "fp16"):
|
| 50 |
+
return torch.float16
|
| 51 |
+
if s in ("bfloat16", "bf16"):
|
| 52 |
+
return torch.bfloat16
|
| 53 |
+
if s in ("float32", "fp32"):
|
| 54 |
+
return torch.float32
|
| 55 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _now_iso() -> str:
|
| 59 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _safe_exp(x: float) -> float:
|
| 63 |
+
x = min(float(x), 50.0)
|
| 64 |
+
return float(math.exp(x))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _ensure_dir(p: Path) -> Path:
|
| 68 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
return p
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 73 |
+
if not p.exists() or not p.is_dir():
|
| 74 |
+
return False
|
| 75 |
+
if (p / "config.json").exists():
|
| 76 |
+
return True
|
| 77 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 78 |
+
return True
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _infer_target_modules(model) -> List[str]:
|
| 83 |
+
names = set()
|
| 84 |
+
for n, _ in model.named_modules():
|
| 85 |
+
names.add(n.split(".")[-1])
|
| 86 |
+
|
| 87 |
+
for group in [
|
| 88 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 89 |
+
["Wqkv", "out_proj"],
|
| 90 |
+
["query_key_value", "dense"],
|
| 91 |
+
["c_attn", "c_proj"],
|
| 92 |
+
]:
|
| 93 |
+
if all(x in names for x in group):
|
| 94 |
+
return group
|
| 95 |
+
|
| 96 |
+
fallback = [
|
| 97 |
+
x
|
| 98 |
+
for x in [
|
| 99 |
+
"q_proj",
|
| 100 |
+
"k_proj",
|
| 101 |
+
"v_proj",
|
| 102 |
+
"o_proj",
|
| 103 |
+
"c_attn",
|
| 104 |
+
"c_proj",
|
| 105 |
+
"out_proj",
|
| 106 |
+
"dense",
|
| 107 |
+
]
|
| 108 |
+
if x in names
|
| 109 |
+
]
|
| 110 |
+
if fallback:
|
| 111 |
+
return fallback
|
| 112 |
+
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 119 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# --------------------------
|
| 123 |
+
# Wandb Integration
|
| 124 |
+
# --------------------------
|
| 125 |
+
|
| 126 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 127 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 128 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 129 |
+
|
| 130 |
+
if not wandb_cfg.get("enabled", False):
|
| 131 |
+
print("Wandb logging disabled")
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
if not WANDB_AVAILABLE:
|
| 135 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
# Extract wandb configuration
|
| 139 |
+
project = wandb_cfg.get("project", "sft-training")
|
| 140 |
+
entity = wandb_cfg.get("entity", None)
|
| 141 |
+
name = wandb_cfg.get("name", None)
|
| 142 |
+
tags = wandb_cfg.get("tags", [])
|
| 143 |
+
notes = wandb_cfg.get("notes", None)
|
| 144 |
+
|
| 145 |
+
# Initialize wandb
|
| 146 |
+
try:
|
| 147 |
+
wandb.init(
|
| 148 |
+
project=project,
|
| 149 |
+
entity=entity,
|
| 150 |
+
name=name,
|
| 151 |
+
tags=tags,
|
| 152 |
+
notes=notes,
|
| 153 |
+
dir=str(run_dir),
|
| 154 |
+
config={
|
| 155 |
+
"model": cfg.get("model", {}),
|
| 156 |
+
"data": cfg.get("data", {}),
|
| 157 |
+
"peft": cfg.get("peft", {}),
|
| 158 |
+
"train": cfg.get("train", {}),
|
| 159 |
+
"run_dir": str(run_dir),
|
| 160 |
+
}
|
| 161 |
+
)
|
| 162 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 163 |
+
return wandb
|
| 164 |
+
except Exception as e:
|
| 165 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 166 |
+
return None
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def finish_wandb():
|
| 170 |
+
"""Finish Wandb run if active."""
|
| 171 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 172 |
+
wandb.finish()
|
| 173 |
+
print("Wandb run finished")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# --------------------------
|
| 177 |
+
# JSONL Logger Callback
|
| 178 |
+
# --------------------------
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 182 |
+
def __init__(self, run_dir: Path):
|
| 183 |
+
self.run_dir = run_dir
|
| 184 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 185 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 186 |
+
self.start_time = None
|
| 187 |
+
|
| 188 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 189 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 190 |
+
return None
|
| 191 |
+
elapsed = time.time() - self.start_time
|
| 192 |
+
sec_per_step = elapsed / global_step
|
| 193 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 194 |
+
h = int(remaining // 3600)
|
| 195 |
+
m = int((remaining % 3600) // 60)
|
| 196 |
+
s = int(remaining % 60)
|
| 197 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 198 |
+
|
| 199 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 200 |
+
self.start_time = time.time()
|
| 201 |
+
|
| 202 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 203 |
+
if not logs:
|
| 204 |
+
return
|
| 205 |
+
|
| 206 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 207 |
+
progress_pct = (
|
| 208 |
+
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 209 |
+
)
|
| 210 |
+
epoch_pct = None
|
| 211 |
+
if (
|
| 212 |
+
state.epoch is not None
|
| 213 |
+
and args.num_train_epochs
|
| 214 |
+
and args.num_train_epochs > 0
|
| 215 |
+
):
|
| 216 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 217 |
+
|
| 218 |
+
payload = {
|
| 219 |
+
"ts": _now_iso(),
|
| 220 |
+
"event": "train_log",
|
| 221 |
+
"step": int(state.global_step),
|
| 222 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 223 |
+
"progress_pct": (
|
| 224 |
+
round(progress_pct, 2) if progress_pct is not None else None
|
| 225 |
+
),
|
| 226 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 227 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 228 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 229 |
+
**logs,
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 233 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 234 |
+
|
| 235 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 236 |
+
if not metrics:
|
| 237 |
+
return
|
| 238 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 239 |
+
ppl = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 240 |
+
|
| 241 |
+
payload = {
|
| 242 |
+
"ts": _now_iso(),
|
| 243 |
+
"event": "eval",
|
| 244 |
+
"step": int(state.global_step),
|
| 245 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 246 |
+
**metrics,
|
| 247 |
+
"perplexity": ppl,
|
| 248 |
+
}
|
| 249 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 250 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# --------------------------
|
| 254 |
+
# Data Pipeline (Instruction Formatting)
|
| 255 |
+
# --------------------------
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def format_instruction(
|
| 259 |
+
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
|
| 260 |
+
) -> Dict[str, Any]:
|
| 261 |
+
"""
|
| 262 |
+
Format instruction data for training.
|
| 263 |
+
Supports multiple formats: chatml, alpaca, custom templates.
|
| 264 |
+
Returns both formatted text and the response start position for loss masking.
|
| 265 |
+
"""
|
| 266 |
+
data_cfg = cfg["data"]
|
| 267 |
+
format_type = data_cfg.get("format_type", "chatml")
|
| 268 |
+
|
| 269 |
+
# Get field names from config
|
| 270 |
+
input_field = data_cfg.get("input_field", "input")
|
| 271 |
+
output_field = data_cfg.get("output_field", "output")
|
| 272 |
+
instruction_field = data_cfg.get("instruction_field", "instruction")
|
| 273 |
+
|
| 274 |
+
# Extract text from example
|
| 275 |
+
instruction = example.get(instruction_field, "")
|
| 276 |
+
input_text = example.get(input_field, "")
|
| 277 |
+
output_text = example.get(output_field, "")
|
| 278 |
+
|
| 279 |
+
if format_type == "chatml":
|
| 280 |
+
# ChatML format with special tokens
|
| 281 |
+
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
|
| 282 |
+
|
| 283 |
+
messages = []
|
| 284 |
+
if system_prompt:
|
| 285 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 286 |
+
|
| 287 |
+
user_content = instruction
|
| 288 |
+
if input_text:
|
| 289 |
+
user_content = f"{instruction}\n\n{input_text}"
|
| 290 |
+
messages.append({"role": "user", "content": user_content})
|
| 291 |
+
messages.append({"role": "assistant", "content": output_text})
|
| 292 |
+
|
| 293 |
+
# Apply chat template
|
| 294 |
+
formatted_text = tokenizer.apply_chat_template(
|
| 295 |
+
messages, tokenize=False, add_generation_prompt=False
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Add EOS token if not present
|
| 299 |
+
if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
|
| 300 |
+
formatted_text += tokenizer.eos_token
|
| 301 |
+
|
| 302 |
+
# Find where the assistant response starts for loss masking
|
| 303 |
+
# Try multiple possible markers for robustness
|
| 304 |
+
markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
|
| 305 |
+
response_start_pos = -1
|
| 306 |
+
|
| 307 |
+
for marker in markers:
|
| 308 |
+
idx = formatted_text.find(marker)
|
| 309 |
+
if idx != -1:
|
| 310 |
+
# Find the newline after the marker
|
| 311 |
+
newline_idx = formatted_text.find("\n", idx)
|
| 312 |
+
if newline_idx != -1:
|
| 313 |
+
response_start_pos = newline_idx + 1
|
| 314 |
+
break
|
| 315 |
+
|
| 316 |
+
# Fallback: find where the actual output starts
|
| 317 |
+
if response_start_pos == -1:
|
| 318 |
+
output_idx = formatted_text.find(output_text)
|
| 319 |
+
if output_idx != -1:
|
| 320 |
+
response_start_pos = output_idx
|
| 321 |
+
else:
|
| 322 |
+
# Last resort: split at last occurrence of newline before end
|
| 323 |
+
response_start_pos = formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1
|
| 324 |
+
|
| 325 |
+
elif format_type == "alpaca":
|
| 326 |
+
# Alpaca format
|
| 327 |
+
if input_text:
|
| 328 |
+
prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
|
| 329 |
+
else:
|
| 330 |
+
prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
|
| 331 |
+
|
| 332 |
+
formatted_text = prefix + output_text
|
| 333 |
+
|
| 334 |
+
# Add EOS token
|
| 335 |
+
if tokenizer.eos_token:
|
| 336 |
+
formatted_text += tokenizer.eos_token
|
| 337 |
+
|
| 338 |
+
# Response starts after the prefix
|
| 339 |
+
response_start_pos = len(prefix)
|
| 340 |
+
|
| 341 |
+
elif format_type == "custom":
|
| 342 |
+
# Custom template from config
|
| 343 |
+
template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
|
| 344 |
+
|
| 345 |
+
# For custom format, use system_prompt as instruction if instruction field is empty
|
| 346 |
+
if not instruction:
|
| 347 |
+
instruction = data_cfg.get("system_prompt", "")
|
| 348 |
+
|
| 349 |
+
# For custom templates, we need to find where {output} starts
|
| 350 |
+
template_parts = template.split("{output}")
|
| 351 |
+
prefix = template_parts[0].format(instruction=instruction, input=input_text)
|
| 352 |
+
formatted_text = prefix + output_text
|
| 353 |
+
|
| 354 |
+
# Add EOS token if not already in template
|
| 355 |
+
if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
|
| 356 |
+
formatted_text += tokenizer.eos_token
|
| 357 |
+
|
| 358 |
+
# Response starts after the prefix
|
| 359 |
+
response_start_pos = len(prefix)
|
| 360 |
+
else:
|
| 361 |
+
raise ValueError(f"Unsupported format_type: {format_type}")
|
| 362 |
+
|
| 363 |
+
return {"text": formatted_text, "response_start_pos": response_start_pos}
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 367 |
+
"""
|
| 368 |
+
Build datasets for instruction fine-tuning.
|
| 369 |
+
"""
|
| 370 |
+
data_cfg = cfg["data"]
|
| 371 |
+
train_path = data_cfg["train_jsonl"]
|
| 372 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 373 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 374 |
+
max_length = int(data_cfg.get("max_length", 2048))
|
| 375 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 376 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 377 |
+
|
| 378 |
+
# Ensure tokenizer has pad token
|
| 379 |
+
if tokenizer.pad_token is None:
|
| 380 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 381 |
+
|
| 382 |
+
# Load datasets
|
| 383 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 384 |
+
|
| 385 |
+
if eval_path:
|
| 386 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 387 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 388 |
+
else:
|
| 389 |
+
if 0.0 < split_ratio < 1.0:
|
| 390 |
+
split = ds["train"].train_test_split(
|
| 391 |
+
test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
|
| 392 |
+
)
|
| 393 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 394 |
+
else:
|
| 395 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 396 |
+
|
| 397 |
+
# Format instructions and track response start positions
|
| 398 |
+
def format_fn(examples):
|
| 399 |
+
formatted_examples = []
|
| 400 |
+
response_start_positions = []
|
| 401 |
+
for i in range(len(examples[list(examples.keys())[0]])):
|
| 402 |
+
example = {k: examples[k][i] for k in examples.keys()}
|
| 403 |
+
formatted = format_instruction(example, cfg, tokenizer)
|
| 404 |
+
formatted_examples.append(formatted["text"])
|
| 405 |
+
response_start_positions.append(formatted["response_start_pos"])
|
| 406 |
+
return {
|
| 407 |
+
"text": formatted_examples,
|
| 408 |
+
"response_start_pos": response_start_positions
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
formatted_train = dsd["train"].map(
|
| 412 |
+
format_fn,
|
| 413 |
+
batched=True,
|
| 414 |
+
num_proc=num_proc,
|
| 415 |
+
remove_columns=dsd["train"].column_names,
|
| 416 |
+
desc="Formatting train instructions",
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
formatted_eval = None
|
| 420 |
+
if dsd["eval"] is not None:
|
| 421 |
+
formatted_eval = dsd["eval"].map(
|
| 422 |
+
format_fn,
|
| 423 |
+
batched=True,
|
| 424 |
+
num_proc=num_proc,
|
| 425 |
+
remove_columns=dsd["eval"].column_names,
|
| 426 |
+
desc="Formatting eval instructions",
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Tokenize and apply loss masking
|
| 430 |
+
def tokenize_and_mask_fn(examples):
|
| 431 |
+
tokenized = tokenizer(
|
| 432 |
+
examples["text"],
|
| 433 |
+
truncation=True,
|
| 434 |
+
padding=False,
|
| 435 |
+
max_length=max_length,
|
| 436 |
+
return_overflowing_tokens=False,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Apply loss masking - CRITICAL for SFT
|
| 440 |
+
labels = []
|
| 441 |
+
attention_masks = []
|
| 442 |
+
|
| 443 |
+
for i in range(len(tokenized["input_ids"])):
|
| 444 |
+
input_ids = tokenized["input_ids"][i]
|
| 445 |
+
response_start_pos = examples["response_start_pos"][i]
|
| 446 |
+
|
| 447 |
+
# Get the instruction part (before response)
|
| 448 |
+
full_text = examples["text"][i]
|
| 449 |
+
instruction_text = full_text[:response_start_pos]
|
| 450 |
+
|
| 451 |
+
# Create labels masked by default
|
| 452 |
+
label_ids = [-100] * len(input_ids)
|
| 453 |
+
|
| 454 |
+
# Find where response starts using character-based ratio
|
| 455 |
+
# This is more reliable than tokenizing prefix separately
|
| 456 |
+
# because separate tokenization can add different special tokens
|
| 457 |
+
char_ratio = response_start_pos / max(len(full_text), 1)
|
| 458 |
+
response_start_idx = int(len(input_ids) * char_ratio)
|
| 459 |
+
|
| 460 |
+
# Ensure we have valid bounds (at least position 1, at most len-1)
|
| 461 |
+
response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
|
| 462 |
+
|
| 463 |
+
# Unmask response tokens (including EOS)
|
| 464 |
+
for j in range(response_start_idx, len(input_ids)):
|
| 465 |
+
label_ids[j] = input_ids[j]
|
| 466 |
+
|
| 467 |
+
# Create attention mask (1 for real tokens, 0 for padding)
|
| 468 |
+
attention_mask = [1] * len(input_ids)
|
| 469 |
+
|
| 470 |
+
labels.append(label_ids)
|
| 471 |
+
attention_masks.append(attention_mask)
|
| 472 |
+
|
| 473 |
+
tokenized["labels"] = labels
|
| 474 |
+
tokenized["attention_mask"] = attention_masks
|
| 475 |
+
return tokenized
|
| 476 |
+
|
| 477 |
+
tokenized_train = formatted_train.map(
|
| 478 |
+
tokenize_and_mask_fn,
|
| 479 |
+
batched=True,
|
| 480 |
+
num_proc=num_proc,
|
| 481 |
+
desc="Tokenizing and masking train",
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
tokenized_eval = None
|
| 485 |
+
if formatted_eval is not None:
|
| 486 |
+
tokenized_eval = formatted_eval.map(
|
| 487 |
+
tokenize_and_mask_fn,
|
| 488 |
+
batched=True,
|
| 489 |
+
num_proc=num_proc,
|
| 490 |
+
desc="Tokenizing and masking eval",
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
if shuffle:
|
| 494 |
+
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 495 |
+
|
| 496 |
+
return tokenized_train, tokenized_eval
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# --------------------------
|
| 500 |
+
# Model Loading + PEFT
|
| 501 |
+
# --------------------------
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 505 |
+
model_cfg = cfg["model"]
|
| 506 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 507 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 508 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 509 |
+
|
| 510 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 511 |
+
str(base_dir),
|
| 512 |
+
use_fast=use_fast,
|
| 513 |
+
trust_remote_code=trust_remote_code,
|
| 514 |
+
)
|
| 515 |
+
if tokenizer.pad_token is None:
|
| 516 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 517 |
+
|
| 518 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 519 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 520 |
+
|
| 521 |
+
quant_cfg = None
|
| 522 |
+
if use_4bit:
|
| 523 |
+
quant_cfg = BitsAndBytesConfig(
|
| 524 |
+
load_in_4bit=True,
|
| 525 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 526 |
+
bnb_4bit_use_double_quant=bool(
|
| 527 |
+
model_cfg.get("bnb_4bit_use_double_quant", True)
|
| 528 |
+
),
|
| 529 |
+
bnb_4bit_compute_dtype=_dtype_from_str(
|
| 530 |
+
model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
|
| 531 |
+
),
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 535 |
+
|
| 536 |
+
# First check the model type to determine loading strategy
|
| 537 |
+
try:
|
| 538 |
+
config = AutoConfig.from_pretrained(str(base_dir), trust_remote_code=True)
|
| 539 |
+
model_type = config.model_type
|
| 540 |
+
architectures = getattr(config, 'architectures', [])
|
| 541 |
+
|
| 542 |
+
# Handle Mistral3 (multimodal) models
|
| 543 |
+
if model_type == "mistral3" or (architectures and "Mistral3" in architectures[0]):
|
| 544 |
+
print(f"[info] Detected Mistral3 model architecture, loading with specific class")
|
| 545 |
+
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
|
| 546 |
+
|
| 547 |
+
try:
|
| 548 |
+
model = Mistral3ForConditionalGeneration.from_pretrained(
|
| 549 |
+
str(base_dir),
|
| 550 |
+
config=config,
|
| 551 |
+
device_map=device_map,
|
| 552 |
+
low_cpu_mem_usage=True,
|
| 553 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 554 |
+
quantization_config=quant_cfg,
|
| 555 |
+
attn_implementation=attn_impl,
|
| 556 |
+
)
|
| 557 |
+
except Exception as e:
|
| 558 |
+
if attn_impl is not None:
|
| 559 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 560 |
+
print("[warn] Falling back to default attention implementation.")
|
| 561 |
+
model = Mistral3ForConditionalGeneration.from_pretrained(
|
| 562 |
+
str(base_dir),
|
| 563 |
+
config=config,
|
| 564 |
+
device_map=device_map,
|
| 565 |
+
low_cpu_mem_usage=True,
|
| 566 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 567 |
+
quantization_config=quant_cfg,
|
| 568 |
+
)
|
| 569 |
+
else:
|
| 570 |
+
raise e
|
| 571 |
+
else:
|
| 572 |
+
# Standard AutoModelForCausalLM loading for other models
|
| 573 |
+
try:
|
| 574 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 575 |
+
str(base_dir),
|
| 576 |
+
device_map=device_map,
|
| 577 |
+
trust_remote_code=True,
|
| 578 |
+
low_cpu_mem_usage=True,
|
| 579 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 580 |
+
quantization_config=quant_cfg,
|
| 581 |
+
attn_implementation=attn_impl,
|
| 582 |
+
)
|
| 583 |
+
except Exception as e:
|
| 584 |
+
if attn_impl is not None:
|
| 585 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 586 |
+
print("[warn] Falling back to default attention implementation.")
|
| 587 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 588 |
+
str(base_dir),
|
| 589 |
+
device_map=device_map,
|
| 590 |
+
trust_remote_code=True,
|
| 591 |
+
low_cpu_mem_usage=True,
|
| 592 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 593 |
+
quantization_config=quant_cfg,
|
| 594 |
+
)
|
| 595 |
+
else:
|
| 596 |
+
raise e
|
| 597 |
+
except Exception as e:
|
| 598 |
+
print(f"[error] Failed to load model: {e}")
|
| 599 |
+
raise e
|
| 600 |
+
|
| 601 |
+
# Ensure all parameters are off meta device
|
| 602 |
+
print("[info] Ensuring all parameters are materialized...")
|
| 603 |
+
meta_params = []
|
| 604 |
+
for name, param in model.named_parameters():
|
| 605 |
+
if param.device.type == 'meta':
|
| 606 |
+
meta_params.append(name)
|
| 607 |
+
|
| 608 |
+
if meta_params:
|
| 609 |
+
print(f"[warn] Found {len(meta_params)} parameters on meta device")
|
| 610 |
+
# For multimodal models, freeze vision components if doing text-only training
|
| 611 |
+
if hasattr(model, 'vision_tower'):
|
| 612 |
+
print("[info] Freezing vision tower for text-only training")
|
| 613 |
+
for param in model.vision_tower.parameters():
|
| 614 |
+
param.requires_grad = False
|
| 615 |
+
|
| 616 |
+
return model, tokenizer
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 620 |
+
peft_cfg = cfg["peft"]
|
| 621 |
+
model_cfg = cfg["model"]
|
| 622 |
+
tr_cfg = cfg["train"]
|
| 623 |
+
|
| 624 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 625 |
+
return model, None
|
| 626 |
+
|
| 627 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 628 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 629 |
+
|
| 630 |
+
# For multimodal models, ensure vision tower doesn't use gradient checkpointing
|
| 631 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 632 |
+
if hasattr(model, 'vision_tower'):
|
| 633 |
+
print("[info] Disabling gradient checkpointing for vision tower")
|
| 634 |
+
# Only enable gradient checkpointing on language model
|
| 635 |
+
if hasattr(model, 'language_model'):
|
| 636 |
+
model.language_model.gradient_checkpointing_enable()
|
| 637 |
+
elif hasattr(model, 'lm_head'):
|
| 638 |
+
model.gradient_checkpointing_enable()
|
| 639 |
+
else:
|
| 640 |
+
model.gradient_checkpointing_enable()
|
| 641 |
+
|
| 642 |
+
if hasattr(model, "config"):
|
| 643 |
+
model.config.use_cache = False
|
| 644 |
+
|
| 645 |
+
if use_4bit:
|
| 646 |
+
model = prepare_model_for_kbit_training(
|
| 647 |
+
model,
|
| 648 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 652 |
+
if target_modules == "auto":
|
| 653 |
+
target_modules = _infer_target_modules(model)
|
| 654 |
+
|
| 655 |
+
# For multimodal models, ensure we only target language model modules
|
| 656 |
+
if hasattr(model, 'vision_tower') and isinstance(target_modules, list):
|
| 657 |
+
print(f"[info] Filtering target modules to exclude vision tower")
|
| 658 |
+
# Filter out any vision tower modules
|
| 659 |
+
target_modules = [m for m in target_modules if 'vision' not in m.lower()]
|
| 660 |
+
print(f"[info] LoRA target modules: {target_modules}")
|
| 661 |
+
|
| 662 |
+
lora_config = LoraConfig(
|
| 663 |
+
r=int(peft_cfg.get("r", 16)),
|
| 664 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 665 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 666 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 667 |
+
task_type="CAUSAL_LM",
|
| 668 |
+
target_modules=target_modules,
|
| 669 |
+
modules_to_save=None, # Don't update any additional modules
|
| 670 |
+
)
|
| 671 |
+
model = get_peft_model(model, lora_config)
|
| 672 |
+
return model, lora_config
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
# --------------------------
|
| 676 |
+
# Merge Logic
|
| 677 |
+
# --------------------------
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def merge_adapter(
|
| 681 |
+
cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
|
| 682 |
+
):
|
| 683 |
+
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 684 |
+
|
| 685 |
+
model_cfg = cfg["model"]
|
| 686 |
+
merge_cfg = cfg.get("merge", {})
|
| 687 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 688 |
+
|
| 689 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 690 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 691 |
+
|
| 692 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 693 |
+
str(base_dir),
|
| 694 |
+
torch_dtype=merged_dtype,
|
| 695 |
+
device_map="cpu",
|
| 696 |
+
low_cpu_mem_usage=True,
|
| 697 |
+
trust_remote_code=trust_remote_code,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 701 |
+
merged = merged.merge_and_unload()
|
| 702 |
+
|
| 703 |
+
_ensure_dir(final_dir)
|
| 704 |
+
merged.save_pretrained(
|
| 705 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
tok = AutoTokenizer.from_pretrained(
|
| 709 |
+
str(base_dir), trust_remote_code=trust_remote_code
|
| 710 |
+
)
|
| 711 |
+
if tok.pad_token is None:
|
| 712 |
+
tok.pad_token = tok.eos_token
|
| 713 |
+
tok.save_pretrained(str(final_dir))
|
| 714 |
+
|
| 715 |
+
print("--- Merge complete ---")
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# --------------------------
|
| 719 |
+
# Main
|
| 720 |
+
# --------------------------
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def main():
|
| 724 |
+
ap = argparse.ArgumentParser()
|
| 725 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 726 |
+
ap.add_argument(
|
| 727 |
+
"--merge-only", action="store_true", help="Skip training, just merge adapter"
|
| 728 |
+
)
|
| 729 |
+
args = ap.parse_args()
|
| 730 |
+
|
| 731 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 732 |
+
cfg = yaml.safe_load(f)
|
| 733 |
+
|
| 734 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 735 |
+
_ensure_dir(run_dir / "logs")
|
| 736 |
+
|
| 737 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 738 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 739 |
+
|
| 740 |
+
model_cfg = cfg["model"]
|
| 741 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 742 |
+
repo_path = Path(repo_id)
|
| 743 |
+
|
| 744 |
+
# ✅ Local model path -> load directly; no download
|
| 745 |
+
if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
|
| 746 |
+
base_dir = repo_path
|
| 747 |
+
print(f"Using local model at: {base_dir}")
|
| 748 |
+
elif repo_path.exists() and repo_path.is_dir():
|
| 749 |
+
raise ValueError(
|
| 750 |
+
f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}"
|
| 751 |
+
)
|
| 752 |
+
else:
|
| 753 |
+
# HF repo_id -> download into run_dir/base_local_dir
|
| 754 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 755 |
+
if not _looks_like_model_dir(base_dir):
|
| 756 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 757 |
+
snapshot_download(
|
| 758 |
+
repo_id=repo_id,
|
| 759 |
+
revision=model_cfg.get("revision", None),
|
| 760 |
+
local_dir=str(base_dir),
|
| 761 |
+
local_dir_use_symlinks=False,
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 765 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 766 |
+
|
| 767 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 768 |
+
if merge_cfg.get("output_dir"):
|
| 769 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 770 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 771 |
+
else:
|
| 772 |
+
final_dir = run_dir / "final_model"
|
| 773 |
+
|
| 774 |
+
# Merge-only
|
| 775 |
+
if args.merge_only:
|
| 776 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 777 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 778 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 779 |
+
return
|
| 780 |
+
|
| 781 |
+
# Initialize Wandb
|
| 782 |
+
wandb_run = setup_wandb(cfg, run_dir)
|
| 783 |
+
|
| 784 |
+
# Training
|
| 785 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 786 |
+
|
| 787 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 788 |
+
model, _ = apply_peft(cfg, model)
|
| 789 |
+
|
| 790 |
+
train_ds, eval_ds = build_datasets(cfg, tokenizer)
|
| 791 |
+
|
| 792 |
+
tr_cfg = cfg["train"]
|
| 793 |
+
|
| 794 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 795 |
+
use_fp16 = dtype == torch.float16
|
| 796 |
+
use_bf16 = dtype == torch.bfloat16
|
| 797 |
+
|
| 798 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 799 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 800 |
+
|
| 801 |
+
# --- Dynamic evaluation strategy parameter handling ---
|
| 802 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 803 |
+
eval_key = (
|
| 804 |
+
"eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
# Setup reporting based on wandb availability
|
| 808 |
+
report_to = []
|
| 809 |
+
if wandb_run is not None:
|
| 810 |
+
report_to.append("wandb")
|
| 811 |
+
|
| 812 |
+
ta_kwargs = dict(
|
| 813 |
+
output_dir=str(ckpt_dir),
|
| 814 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 815 |
+
num_train_epochs=num_train_epochs,
|
| 816 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 817 |
+
per_device_eval_batch_size=int(
|
| 818 |
+
tr_cfg.get(
|
| 819 |
+
"per_device_eval_batch_size",
|
| 820 |
+
tr_cfg.get("per_device_train_batch_size", 1),
|
| 821 |
+
)
|
| 822 |
+
),
|
| 823 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 824 |
+
learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
|
| 825 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 826 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 827 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 828 |
+
optim=str(
|
| 829 |
+
tr_cfg.get(
|
| 830 |
+
"optim",
|
| 831 |
+
(
|
| 832 |
+
"paged_adamw_8bit"
|
| 833 |
+
if bool(model_cfg.get("use_4bit", False))
|
| 834 |
+
else "adamw_torch"
|
| 835 |
+
),
|
| 836 |
+
)
|
| 837 |
+
),
|
| 838 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 839 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 840 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 841 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 842 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 843 |
+
eval_steps=int(tr_cfg.get("eval_steps", 200)),
|
| 844 |
+
load_best_model_at_end=(
|
| 845 |
+
bool(tr_cfg.get("load_best_model_at_end", True))
|
| 846 |
+
if eval_ds is not None
|
| 847 |
+
else False
|
| 848 |
+
),
|
| 849 |
+
metric_for_best_model="eval_loss",
|
| 850 |
+
greater_is_better=False,
|
| 851 |
+
fp16=use_fp16,
|
| 852 |
+
bf16=use_bf16,
|
| 853 |
+
report_to=report_to,
|
| 854 |
+
remove_unused_columns=False,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# Set the correct argument name for this transformers version
|
| 858 |
+
ta_kwargs[eval_key] = str(
|
| 859 |
+
tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 863 |
+
|
| 864 |
+
# Setup callbacks
|
| 865 |
+
callbacks = [JsonlLoggerCallback(run_dir)]
|
| 866 |
+
|
| 867 |
+
# Add early stopping callback if enabled
|
| 868 |
+
early_stopping_cfg = tr_cfg.get("early_stopping", {})
|
| 869 |
+
if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
|
| 870 |
+
early_stopping_callback = EarlyStoppingCallback(
|
| 871 |
+
early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
|
| 872 |
+
early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
|
| 873 |
+
)
|
| 874 |
+
callbacks.append(early_stopping_callback)
|
| 875 |
+
print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
|
| 876 |
+
f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
|
| 877 |
+
|
| 878 |
+
trainer = Trainer(
|
| 879 |
+
model=model,
|
| 880 |
+
args=training_args,
|
| 881 |
+
train_dataset=train_ds,
|
| 882 |
+
eval_dataset=eval_ds,
|
| 883 |
+
data_collator=default_data_collator,
|
| 884 |
+
callbacks=callbacks,
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
# Resume
|
| 888 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 889 |
+
if resume_from == "auto":
|
| 890 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 891 |
+
resume_from = last if last else None
|
| 892 |
+
if resume_from:
|
| 893 |
+
print(f"Resuming from {resume_from}")
|
| 894 |
+
|
| 895 |
+
print("Starting instruction fine-tuning...")
|
| 896 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 897 |
+
|
| 898 |
+
trainer.save_model(str(best_adapter_dir))
|
| 899 |
+
print(f"Saved best adapter -> {best_adapter_dir}")
|
| 900 |
+
|
| 901 |
+
if eval_ds is not None:
|
| 902 |
+
metrics = trainer.evaluate()
|
| 903 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 904 |
+
metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
|
| 905 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 906 |
+
json.dump(metrics, f, indent=2)
|
| 907 |
+
print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
|
| 908 |
+
|
| 909 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 910 |
+
del trainer, model
|
| 911 |
+
torch.cuda.empty_cache()
|
| 912 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 913 |
+
else:
|
| 914 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 915 |
+
|
| 916 |
+
# Finish Wandb run
|
| 917 |
+
finish_wandb()
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
if __name__ == "__main__":
|
| 921 |
+
main()
|
trainer-kit/documentation.md
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CPT Training Different Modules Guide
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
By default, the CPT (Continual Pre-Training) configuration in `/workspace/Trainer-kit/CPT/config.yaml` trains only **attention projection layers** using LoRA adapters. This guide explains how to modify the configuration to train other modules.
|
| 6 |
+
|
| 7 |
+
## Current Default Configuration
|
| 8 |
+
|
| 9 |
+
```yaml
|
| 10 |
+
peft:
|
| 11 |
+
enabled: true
|
| 12 |
+
target_modules: "auto"
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
When `target_modules: "auto"` is set, the script automatically detects and trains these attention layers:
|
| 16 |
+
- `q_proj` - Query projection
|
| 17 |
+
- `k_proj` - Key projection
|
| 18 |
+
- `v_proj` - Value projection
|
| 19 |
+
- `o_proj` - Output projection
|
| 20 |
+
|
| 21 |
+
## How to Train Other Modules
|
| 22 |
+
|
| 23 |
+
### Method 1: Explicit Target Modules
|
| 24 |
+
|
| 25 |
+
Replace `"auto"` with a list of specific module names you want to train:
|
| 26 |
+
|
| 27 |
+
```yaml
|
| 28 |
+
peft:
|
| 29 |
+
enabled: true
|
| 30 |
+
target_modules:
|
| 31 |
+
- "q_proj"
|
| 32 |
+
- "k_proj"
|
| 33 |
+
- "v_proj"
|
| 34 |
+
- "o_proj"
|
| 35 |
+
- "mlp.down_proj" # Add MLP down projection
|
| 36 |
+
- "mlp.gate_proj" # Add MLP gate projection
|
| 37 |
+
- "mlp.up_proj" # Add MLP up projection
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Method 2: Custom Module Lists
|
| 41 |
+
|
| 42 |
+
For different model architectures, here are common modules you can train:
|
| 43 |
+
|
| 44 |
+
#### LLaMA/Llama-style Models
|
| 45 |
+
```yaml
|
| 46 |
+
peft:
|
| 47 |
+
enabled: true
|
| 48 |
+
target_modules:
|
| 49 |
+
- "q_proj"
|
| 50 |
+
- "k_proj"
|
| 51 |
+
- "v_proj"
|
| 52 |
+
- "o_proj"
|
| 53 |
+
- "mlp.gate_proj"
|
| 54 |
+
- "mlp.up_proj"
|
| 55 |
+
- "mlp.down_proj"
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
#### Qwen-style Models
|
| 59 |
+
```yaml
|
| 60 |
+
peft:
|
| 61 |
+
enabled: true
|
| 62 |
+
target_modules:
|
| 63 |
+
- "q_proj"
|
| 64 |
+
- "k_proj"
|
| 65 |
+
- "v_proj"
|
| 66 |
+
- "o_proj"
|
| 67 |
+
- "mlp.gate_proj"
|
| 68 |
+
- "mlp.up_proj"
|
| 69 |
+
- "mlp.down_proj"
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
#### Mixtral/Gemma-style Models
|
| 73 |
+
```yaml
|
| 74 |
+
peft:
|
| 75 |
+
enabled: true
|
| 76 |
+
target_modules:
|
| 77 |
+
- "q_proj"
|
| 78 |
+
- "k_proj"
|
| 79 |
+
- "v_proj"
|
| 80 |
+
- "o_proj"
|
| 81 |
+
- "mlp.experts.*.w1" # Expert layer 1
|
| 82 |
+
- "mlp.experts.*.w2" # Expert layer 2
|
| 83 |
+
- "mlp.experts.*.w3" # Expert layer 3
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## Module Types You Can Train
|
| 87 |
+
|
| 88 |
+
### 1. Attention Layers
|
| 89 |
+
- `q_proj` - Query projections
|
| 90 |
+
- `k_proj` - Key projections
|
| 91 |
+
- `v_proj` - Value projections
|
| 92 |
+
- `o_proj` - Output projections
|
| 93 |
+
- `qkv_proj` - Combined QKV (in some models)
|
| 94 |
+
- `c_attn` - Attention in older models
|
| 95 |
+
|
| 96 |
+
### 2. MLP/Feed-Forward Layers
|
| 97 |
+
- `mlp.gate_proj` - Gate projection
|
| 98 |
+
- `mlp.up_proj` - Up projection
|
| 99 |
+
- `mlp.down_proj` - Down projection
|
| 100 |
+
- `mlp.fc1` - First layer
|
| 101 |
+
- `mlp.fc2` - Second layer
|
| 102 |
+
- `w1`, `w2`, `w3` - Alternative naming
|
| 103 |
+
|
| 104 |
+
### 3. Embedding Layers
|
| 105 |
+
```yaml
|
| 106 |
+
peft:
|
| 107 |
+
enabled: true
|
| 108 |
+
target_modules:
|
| 109 |
+
- "model.embed_tokens" # Token embeddings
|
| 110 |
+
- "lm_head" # Language model head
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### 4. Normalization Layers
|
| 114 |
+
```yaml
|
| 115 |
+
peft:
|
| 116 |
+
enabled: true
|
| 117 |
+
target_modules:
|
| 118 |
+
- "input_layernorm" # Input normalization
|
| 119 |
+
- "post_attention_layernorm" # Post-attention norm
|
| 120 |
+
- "final_layernorm" # Final normalization
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### 5. MoE (Mixture of Experts) Layers
|
| 124 |
+
```yaml
|
| 125 |
+
peft:
|
| 126 |
+
enabled: true
|
| 127 |
+
target_modules:
|
| 128 |
+
- "mlp.experts.*.w1" # Expert layer 1
|
| 129 |
+
- "mlp.experts.*.w2" # Expert layer 2
|
| 130 |
+
- "mlp.experts.*.w3" # Expert layer 3
|
| 131 |
+
- "mlp.gate" # Expert routing gate
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## Advanced Configuration Examples
|
| 135 |
+
|
| 136 |
+
### Train Multiple Layer Types
|
| 137 |
+
```yaml
|
| 138 |
+
peft:
|
| 139 |
+
enabled: true
|
| 140 |
+
target_modules:
|
| 141 |
+
- "q_proj"
|
| 142 |
+
- "k_proj"
|
| 143 |
+
- "v_proj"
|
| 144 |
+
- "o_proj"
|
| 145 |
+
- "mlp.gate_proj"
|
| 146 |
+
- "mlp.up_proj"
|
| 147 |
+
- "mlp.down_proj"
|
| 148 |
+
- "input_layernorm"
|
| 149 |
+
- "post_attention_layernorm"
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
### Conservative Approach (Only MLPs)
|
| 153 |
+
```yaml
|
| 154 |
+
peft:
|
| 155 |
+
enabled: true
|
| 156 |
+
target_modules:
|
| 157 |
+
- "mlp.gate_proj"
|
| 158 |
+
- "mlp.up_proj"
|
| 159 |
+
- "mlp.down_proj"
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### Comprehensive Approach (All Main Layers)
|
| 163 |
+
```yaml
|
| 164 |
+
peft:
|
| 165 |
+
enabled: true
|
| 166 |
+
target_modules:
|
| 167 |
+
- "q_proj"
|
| 168 |
+
- "k_proj"
|
| 169 |
+
- "v_proj"
|
| 170 |
+
- "o_proj"
|
| 171 |
+
- "mlp.gate_proj"
|
| 172 |
+
- "mlp.up_proj"
|
| 173 |
+
- "mlp.down_proj"
|
| 174 |
+
- "input_layernorm"
|
| 175 |
+
- "post_attention_layernorm"
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
## How to Find Module Names for Your Model
|
| 179 |
+
|
| 180 |
+
### Method 1: Automatic Detection
|
| 181 |
+
Run the script once with `target_modules: "auto"` - it will log which modules it found:
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
Using auto-inferred target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
### Method 2: Manual Inspection
|
| 188 |
+
Inspect your model structure:
|
| 189 |
+
|
| 190 |
+
```python
|
| 191 |
+
import torch
|
| 192 |
+
from transformers import AutoModel
|
| 193 |
+
|
| 194 |
+
model = AutoModel.from_pretrained("/workspace/Models/YourModel")
|
| 195 |
+
|
| 196 |
+
# Print all module names
|
| 197 |
+
for name, module in model.named_modules():
|
| 198 |
+
print(name)
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Method 3: Use PEFT's Built-in Function
|
| 202 |
+
The script includes `_infer_target_modules()` function that can help identify available modules.
|
| 203 |
+
|
| 204 |
+
## Considerations
|
| 205 |
+
|
| 206 |
+
### 1. Memory Usage
|
| 207 |
+
- **More modules = More memory**: Training additional layers requires more GPU memory
|
| 208 |
+
- **Monitor VRAM usage**: Use `nvidia-smi` to monitor memory consumption
|
| 209 |
+
- **Adjust batch size**: You may need to reduce `per_device_train_batch_size`
|
| 210 |
+
|
| 211 |
+
### 2. Training Time
|
| 212 |
+
- **More modules = Longer training**: Each additional layer increases computation time
|
| 213 |
+
- **Learning rate adjustments**: You might need to reduce `learning_rate` when training more layers
|
| 214 |
+
|
| 215 |
+
### 3. Performance Trade-offs
|
| 216 |
+
- **Attention only**: Fast training, good for language understanding
|
| 217 |
+
- **MLP only**: Fast training, good for knowledge storage
|
| 218 |
+
- **Both attention + MLP**: Slower but potentially better performance
|
| 219 |
+
- **All layers**: Slowest but most comprehensive adaptation
|
| 220 |
+
|
| 221 |
+
### 4. Model Architecture Differences
|
| 222 |
+
Different model families use different module naming conventions:
|
| 223 |
+
- **LLaMA**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
|
| 224 |
+
- **Qwen**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
|
| 225 |
+
- **Gemma**: `mlp.gate_proj`, `mlp.up_proj`, `mlp.down_proj`
|
| 226 |
+
- **Mixtral**: `mlp.experts.*.w1`, etc.
|
| 227 |
+
|
| 228 |
+
## Best Practices
|
| 229 |
+
|
| 230 |
+
### 1. Start Conservative
|
| 231 |
+
Begin with just attention layers, then gradually add more modules if needed:
|
| 232 |
+
```yaml
|
| 233 |
+
# Phase 1: Start here
|
| 234 |
+
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
| 235 |
+
|
| 236 |
+
# Phase 2: Add MLPs
|
| 237 |
+
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "mlp.down_proj"]
|
| 238 |
+
|
| 239 |
+
# Phase 3: Add more if needed
|
| 240 |
+
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
### 2. Monitor Overfitting
|
| 244 |
+
- Use evaluation split to monitor performance
|
| 245 |
+
- Adjust `learning_rate` if overfitting occurs
|
| 246 |
+
- Consider `lora_dropout` to reduce overfitting
|
| 247 |
+
|
| 248 |
+
### 3. Resource Management
|
| 249 |
+
- Start with small LoRA rank (`r: 16`) if training many modules
|
| 250 |
+
- Increase `gradient_accumulation_steps` if reducing batch size
|
| 251 |
+
- Monitor GPU memory usage throughout training
|
| 252 |
+
|
| 253 |
+
### 4. Model-Specific Tuning
|
| 254 |
+
Different models may benefit from different module combinations:
|
| 255 |
+
- **Code models**: Focus on attention + MLP layers
|
| 256 |
+
- **Chat models**: Attention layers are most important
|
| 257 |
+
- **Reasoning models**: All layers might be beneficial
|
| 258 |
+
|
| 259 |
+
## Example: Training Custom Modules
|
| 260 |
+
|
| 261 |
+
### Complete Configuration Example
|
| 262 |
+
```yaml
|
| 263 |
+
model:
|
| 264 |
+
repo_id: "/workspace/Models/Devstral-Small-2-24B-Instruct-2512"
|
| 265 |
+
torch_dtype: "bfloat16"
|
| 266 |
+
|
| 267 |
+
peft:
|
| 268 |
+
enabled: true
|
| 269 |
+
r: 64
|
| 270 |
+
lora_alpha: 128
|
| 271 |
+
lora_dropout: 0.05
|
| 272 |
+
bias: "none"
|
| 273 |
+
target_modules:
|
| 274 |
+
- "q_proj"
|
| 275 |
+
- "k_proj"
|
| 276 |
+
- "v_proj"
|
| 277 |
+
- "o_proj"
|
| 278 |
+
- "mlp.gate_proj"
|
| 279 |
+
- "mlp.up_proj"
|
| 280 |
+
- "mlp.down_proj"
|
| 281 |
+
- "input_layernorm"
|
| 282 |
+
|
| 283 |
+
train:
|
| 284 |
+
num_train_epochs: 2
|
| 285 |
+
learning_rate: 1e-5 # Reduced due to more modules
|
| 286 |
+
per_device_train_batch_size: 1
|
| 287 |
+
gradient_accumulation_steps: 16
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
This configuration will train:
|
| 291 |
+
- All attention projection layers
|
| 292 |
+
- All MLP projection layers
|
| 293 |
+
- Input normalization layers
|
| 294 |
+
- Using a reduced learning rate to accommodate the additional trainable parameters.
|
| 295 |
+
|
| 296 |
+
Remember to always test with a small number of steps first to ensure your configuration works correctly before running full training.
|