Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- DPO-14b/README.md +181 -0
- DPO-14b/apply_critical_fixes.py +206 -0
- DPO-14b/config_dpo.yaml +141 -0
- DPO-14b/create_synthetic_pairs.py +136 -0
- DPO-14b/dpo_dataset.jsonl +3 -0
- DPO-14b/dpo_pairs_generated.jsonl +3 -0
- DPO-14b/f1_score_utils.py +283 -0
- DPO-14b/prepare_data.py +343 -0
- DPO-14b/requirements.txt +29 -0
- DPO-14b/run_dpo.py +953 -0
- DPO-14b/run_dpo.py.backup +923 -0
- DPO-14b/run_dpo_enhanced.py +310 -0
- DPO-14b/test_fixes.py +108 -0
.gitattributes
CHANGED
|
@@ -38,3 +38,4 @@ dpo_run_14B/checkpoint-100/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
| 38 |
dpo_run_14B/wandb/run-20251226_152332-r9hfat2g/run-r9hfat2g.wandb filter=lfs diff=lfs merge=lfs -text
|
| 39 |
dpo_run_14B/wandb/run-20251226_152936-r1nptay8/run-r1nptay8.wandb filter=lfs diff=lfs merge=lfs -text
|
| 40 |
dpo_run_14B/wandb/run-20251226_155650-wbzoafvt/run-wbzoafvt.wandb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 38 |
dpo_run_14B/wandb/run-20251226_152332-r9hfat2g/run-r9hfat2g.wandb filter=lfs diff=lfs merge=lfs -text
|
| 39 |
dpo_run_14B/wandb/run-20251226_152936-r1nptay8/run-r1nptay8.wandb filter=lfs diff=lfs merge=lfs -text
|
| 40 |
dpo_run_14B/wandb/run-20251226_155650-wbzoafvt/run-wbzoafvt.wandb filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
DPO-14b/dpo_pairs_generated.jsonl filter=lfs diff=lfs merge=lfs -text
|
DPO-14b/README.md
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DPO Training for Code Analysis
|
| 2 |
+
|
| 3 |
+
This folder contains a Direct Preference Optimization (DPO) trainer for fine-tuning models on code analysis tasks with preference pairs.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
DPO training uses preference pairs (chosen/rejected responses) to optimize the model to prefer better outputs over worse ones. This is particularly useful for tasks where we have multiple responses with different quality levels.
|
| 8 |
+
|
| 9 |
+
## Files
|
| 10 |
+
|
| 11 |
+
- `run_dpo.py` - Main DPO training script
|
| 12 |
+
- `config_dpo.yaml` - Configuration file for DPO training
|
| 13 |
+
- `f1_score_utils.py` - Utilities for computing F1 scores and creating preference pairs
|
| 14 |
+
- `requirements.txt` - Python dependencies
|
| 15 |
+
- `dpo_dataset.jsonl` - Sample DPO dataset
|
| 16 |
+
|
| 17 |
+
## Data Format
|
| 18 |
+
|
| 19 |
+
DPO requires data in the following format:
|
| 20 |
+
|
| 21 |
+
```jsonl
|
| 22 |
+
{
|
| 23 |
+
"prompt": "##TASK\n<task description>",
|
| 24 |
+
"chosen": "<better response with correct file selections>",
|
| 25 |
+
"rejected": "<worse response with incorrect file selections>",
|
| 26 |
+
"chosen_f1": 1.0,
|
| 27 |
+
"rejected_f1": 0.5
|
| 28 |
+
}
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Creating DPO Data from SFT Data
|
| 32 |
+
|
| 33 |
+
You can use the F1 score utility to create DPO pairs from multiple model generations:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from f1_score_utils import create_dpo_pairs_from_generations
|
| 37 |
+
|
| 38 |
+
prompt = "##TASK\nAdd webhook support..."
|
| 39 |
+
generations = [output1, output2, output3, output4] # Multiple model outputs
|
| 40 |
+
ground_truth = "##OUTPUT\n...\n##SELECT\n..."
|
| 41 |
+
|
| 42 |
+
pairs = create_dpo_pairs_from_generations(
|
| 43 |
+
prompt, generations, ground_truth, min_f1_difference=0.1
|
| 44 |
+
)
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## F1 Score Ranking
|
| 48 |
+
|
| 49 |
+
The F1 score is computed at the **file level**:
|
| 50 |
+
- **Precision**: Correct files / Total predicted files
|
| 51 |
+
- **Recall**: Correct files / Total ground truth files
|
| 52 |
+
- **F1**: Harmonic mean of precision and recall
|
| 53 |
+
|
| 54 |
+
Files are extracted from the `##SELECT` section:
|
| 55 |
+
```
|
| 56 |
+
##SELECT
|
| 57 |
+
crates/router/src/webhooks.rs::process_webhook
|
| 58 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 59 |
+
<EOS>
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## Installation
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
pip install -r requirements.txt
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Usage
|
| 69 |
+
|
| 70 |
+
### 1. Prepare DPO Dataset
|
| 71 |
+
|
| 72 |
+
You need to generate multiple outputs for each prompt and rank them by F1 score:
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
from f1_score_utils import compute_file_level_f1, rank_outputs_by_f1
|
| 76 |
+
|
| 77 |
+
# Rank outputs
|
| 78 |
+
ranked = rank_outputs_by_f1(outputs, ground_truth)
|
| 79 |
+
for output, f1, metrics in ranked:
|
| 80 |
+
print(f"F1: {f1:.3f} - {metrics['true_positives']} correct files")
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### 2. Configure Training
|
| 84 |
+
|
| 85 |
+
Edit `config_dpo.yaml`:
|
| 86 |
+
- Set `model.repo_id` to your SFT model path
|
| 87 |
+
- Adjust `dpo.beta` (temperature parameter, default 0.1)
|
| 88 |
+
- Set `dpo.loss_type` (sigmoid, hinge, ipo, kto)
|
| 89 |
+
- Configure training hyperparameters
|
| 90 |
+
|
| 91 |
+
### 3. Run Training
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
python run_dpo.py --config config_dpo.yaml
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### 4. Merge Adapter (Optional)
|
| 98 |
+
|
| 99 |
+
If training is complete and you want to merge the adapter:
|
| 100 |
+
|
| 101 |
+
```bash
|
| 102 |
+
python run_dpo.py --config config_dpo.yaml --merge-only
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Configuration
|
| 106 |
+
|
| 107 |
+
### DPO Parameters
|
| 108 |
+
|
| 109 |
+
- `beta`: Temperature for DPO loss (higher = less aggressive preference learning)
|
| 110 |
+
- `label_smoothing`: Smoothing factor for labels
|
| 111 |
+
- `loss_type`: Type of loss function
|
| 112 |
+
- `sigmoid`: Standard DPO loss (default)
|
| 113 |
+
- `hinge`: Margin-based loss
|
| 114 |
+
- `ipo`: Identity Policy Optimization
|
| 115 |
+
- `kto`: Kahneman-Tversky Optimization
|
| 116 |
+
- `use_reference_model`: Whether to use a frozen reference model
|
| 117 |
+
|
| 118 |
+
### Training Tips
|
| 119 |
+
|
| 120 |
+
1. **Learning Rate**: Use lower LR than SFT (e.g., 5e-5 vs 2e-4)
|
| 121 |
+
2. **Beta**: Start with 0.1, increase for less aggressive learning
|
| 122 |
+
3. **Batch Size**: Larger batches are more stable
|
| 123 |
+
4. **Data Quality**: Ensure significant F1 difference between chosen/rejected (β₯0.1)
|
| 124 |
+
|
| 125 |
+
## Output
|
| 126 |
+
|
| 127 |
+
Training outputs:
|
| 128 |
+
- `runs/dpo_run_14b_v1/checkpoints/` - Training checkpoints
|
| 129 |
+
- `runs/dpo_run_14b_v1/best_adapter/` - Best adapter weights
|
| 130 |
+
- `runs/dpo_run_14b_v1/merged_14b_dpo_lora/` - Merged model
|
| 131 |
+
- `runs/dpo_run_14b_v1/logs/` - Training logs (JSONL format)
|
| 132 |
+
|
| 133 |
+
## WandB Integration
|
| 134 |
+
|
| 135 |
+
Enable experiment tracking in `config_dpo.yaml`:
|
| 136 |
+
|
| 137 |
+
```yaml
|
| 138 |
+
wandb:
|
| 139 |
+
enabled: true
|
| 140 |
+
project: "dpo-training"
|
| 141 |
+
tags: ["dpo-lora", "preference-optimization"]
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## Example: Generate DPO Data
|
| 145 |
+
|
| 146 |
+
```python
|
| 147 |
+
import json
|
| 148 |
+
from f1_score_utils import compute_file_level_f1, create_dpo_pairs_from_generations
|
| 149 |
+
|
| 150 |
+
# Load SFT data
|
| 151 |
+
with open("instruct_data.jsonl") as f:
|
| 152 |
+
for line in f:
|
| 153 |
+
data = json.loads(line)
|
| 154 |
+
prompt = data["input"]
|
| 155 |
+
ground_truth = data["output"]
|
| 156 |
+
|
| 157 |
+
# Generate multiple outputs with your model
|
| 158 |
+
generations = generate_multiple_outputs(prompt, num_samples=4)
|
| 159 |
+
|
| 160 |
+
# Create preference pairs
|
| 161 |
+
pairs = create_dpo_pairs_from_generations(
|
| 162 |
+
prompt, generations, ground_truth, min_f1_difference=0.1
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Save pairs
|
| 166 |
+
with open("dpo_dataset.jsonl", "a") as out:
|
| 167 |
+
for pair in pairs:
|
| 168 |
+
out.write(json.dumps(pair) + "\n")
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
## Troubleshooting
|
| 172 |
+
|
| 173 |
+
1. **OOM Errors**: Reduce batch size or enable gradient checkpointing
|
| 174 |
+
2. **No Improvement**: Check F1 score differences in data, increase beta
|
| 175 |
+
3. **Unstable Training**: Lower learning rate, increase warmup ratio
|
| 176 |
+
4. **Reference Model Issues**: Set `use_reference_model: false` to use implicit reference
|
| 177 |
+
|
| 178 |
+
## References
|
| 179 |
+
|
| 180 |
+
- DPO Paper: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
|
| 181 |
+
- TRL Library: [HuggingFace TRL](https://github.com/huggingface/trl)
|
DPO-14b/apply_critical_fixes.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick fix script to apply critical improvements to run_dpo.py
|
| 4 |
+
Run this to automatically patch the DPO trainer with all critical fixes.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
def backup_file(filepath):
|
| 12 |
+
"""Create backup of original file"""
|
| 13 |
+
backup_path = Path(str(filepath) + '.backup')
|
| 14 |
+
shutil.copy2(filepath, backup_path)
|
| 15 |
+
print(f"β
Backup created: {backup_path}")
|
| 16 |
+
return backup_path
|
| 17 |
+
|
| 18 |
+
def apply_fixes(filepath='run_dpo.py'):
|
| 19 |
+
"""Apply all critical fixes to the DPO training script"""
|
| 20 |
+
|
| 21 |
+
filepath = Path(filepath)
|
| 22 |
+
if not filepath.exists():
|
| 23 |
+
print(f"β Error: {filepath} not found")
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
# Backup original
|
| 27 |
+
backup_file(filepath)
|
| 28 |
+
|
| 29 |
+
with open(filepath, 'r') as f:
|
| 30 |
+
content = f.read()
|
| 31 |
+
|
| 32 |
+
fixes_applied = []
|
| 33 |
+
|
| 34 |
+
# Fix 1: Add missing imports
|
| 35 |
+
if 'import gc' not in content:
|
| 36 |
+
content = content.replace(
|
| 37 |
+
'import time\nfrom pathlib',
|
| 38 |
+
'import gc\nimport time\nimport logging\nfrom pathlib'
|
| 39 |
+
)
|
| 40 |
+
fixes_applied.append("Added gc and logging imports")
|
| 41 |
+
|
| 42 |
+
# Fix 2: Add logging setup
|
| 43 |
+
if 'logging.basicConfig' not in content:
|
| 44 |
+
content = content.replace(
|
| 45 |
+
'wandb = None\n\n\n# --------------------------\n# Helpers',
|
| 46 |
+
'''wandb = None
|
| 47 |
+
|
| 48 |
+
# Setup logging
|
| 49 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# --------------------------
|
| 54 |
+
# Custom Exceptions
|
| 55 |
+
# --------------------------
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DataFormattingError(Exception):
|
| 59 |
+
"""Exception raised for errors in data formatting."""
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class DataValidationError(Exception):
|
| 64 |
+
"""Exception raised for errors in data validation."""
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# --------------------------
|
| 69 |
+
# Helpers'''
|
| 70 |
+
)
|
| 71 |
+
fixes_applied.append("Added logging setup and custom exceptions")
|
| 72 |
+
|
| 73 |
+
# Fix 3: Add validation function
|
| 74 |
+
if 'def validate_dpo_data' not in content:
|
| 75 |
+
validation_func = '''
|
| 76 |
+
|
| 77 |
+
def validate_dpo_data(dataset, stage: str = "train") -> None:
|
| 78 |
+
"""
|
| 79 |
+
Validate DPO dataset has all required fields and proper structure.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
dataset: Dataset to validate
|
| 83 |
+
stage: Training stage ("train" or "eval")
|
| 84 |
+
|
| 85 |
+
Raises:
|
| 86 |
+
DataValidationError if validation fails
|
| 87 |
+
"""
|
| 88 |
+
required_fields = ["prompt", "chosen", "rejected"]
|
| 89 |
+
|
| 90 |
+
# Check required fields exist
|
| 91 |
+
for field in required_fields:
|
| 92 |
+
if field not in dataset.column_names:
|
| 93 |
+
raise DataValidationError(
|
| 94 |
+
f"{stage} dataset missing required field: {field}. "
|
| 95 |
+
f"Available fields: {dataset.column_names}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Sample validation - check first example
|
| 99 |
+
if len(dataset) > 0:
|
| 100 |
+
sample = dataset[0]
|
| 101 |
+
for field in required_fields:
|
| 102 |
+
if not sample[field] or len(sample[field].strip()) == 0:
|
| 103 |
+
logger.warning(f"{stage} dataset has empty {field} in first example")
|
| 104 |
+
|
| 105 |
+
logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")
|
| 106 |
+
|
| 107 |
+
'''
|
| 108 |
+
# Insert before build_dpo_datasets
|
| 109 |
+
content = content.replace(
|
| 110 |
+
'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
|
| 111 |
+
validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)'
|
| 112 |
+
)
|
| 113 |
+
fixes_applied.append("Added data validation function")
|
| 114 |
+
|
| 115 |
+
# Fix 4: Improve merge_adapter with memory cleanup
|
| 116 |
+
old_merge = ''' merged.save_pretrained(
|
| 117 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
tok = AutoTokenizer.from_pretrained('''
|
| 121 |
+
|
| 122 |
+
new_merge = ''' # Clean up base model to free memory
|
| 123 |
+
del base
|
| 124 |
+
gc.collect()
|
| 125 |
+
torch.cuda.empty_cache()
|
| 126 |
+
|
| 127 |
+
merged.save_pretrained(
|
| 128 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Clean up merged model
|
| 132 |
+
del merged
|
| 133 |
+
gc.collect()
|
| 134 |
+
torch.cuda.empty_cache()
|
| 135 |
+
|
| 136 |
+
tok = AutoTokenizer.from_pretrained('''
|
| 137 |
+
|
| 138 |
+
if old_merge in content and 'del base' not in content:
|
| 139 |
+
content = content.replace(old_merge, new_merge)
|
| 140 |
+
fixes_applied.append("Added memory cleanup in merge_adapter")
|
| 141 |
+
|
| 142 |
+
# Fix 5: Add TRL version check
|
| 143 |
+
if 'version.parse(trl.__version__)' not in content:
|
| 144 |
+
content = content.replace(
|
| 145 |
+
'from trl import DPOTrainer, DPOConfig',
|
| 146 |
+
'''from trl import DPOTrainer, DPOConfig
|
| 147 |
+
|
| 148 |
+
# Version check for TRL
|
| 149 |
+
try:
|
| 150 |
+
from packaging import version
|
| 151 |
+
import trl
|
| 152 |
+
if version.parse(trl.__version__) < version.parse("0.7.0"):
|
| 153 |
+
print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
|
| 154 |
+
except ImportError:
|
| 155 |
+
print("Warning: Could not verify TRL version")'''
|
| 156 |
+
)
|
| 157 |
+
fixes_applied.append("Added TRL version check")
|
| 158 |
+
|
| 159 |
+
# Fix 6: Replace some critical print statements with logger
|
| 160 |
+
content = content.replace('print(f"Using local model at:', 'logger.info(f"Using local model at:')
|
| 161 |
+
content = content.replace('print(f"Loading reference model', 'logger.info(f"Loading reference model')
|
| 162 |
+
content = content.replace('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta')
|
| 163 |
+
content = content.replace('print(f"Resuming from', 'logger.info(f"Resuming from')
|
| 164 |
+
content = content.replace('print("Starting DPO training', 'logger.info("Starting DPO training')
|
| 165 |
+
content = content.replace('print(f"Saved best adapter', 'logger.info(f"Saved best adapter')
|
| 166 |
+
|
| 167 |
+
fixes_applied.append("Replaced print with logger calls")
|
| 168 |
+
|
| 169 |
+
# Write fixed content
|
| 170 |
+
with open(filepath, 'w') as f:
|
| 171 |
+
f.write(content)
|
| 172 |
+
|
| 173 |
+
print("\n" + "="*80)
|
| 174 |
+
print("DPO TRAINER - FIXES APPLIED")
|
| 175 |
+
print("="*80)
|
| 176 |
+
for i, fix in enumerate(fixes_applied, 1):
|
| 177 |
+
print(f"{i}. β
{fix}")
|
| 178 |
+
print("="*80)
|
| 179 |
+
print(f"\nβ
All fixes applied successfully to {filepath}")
|
| 180 |
+
print(f"π Original backed up to {filepath}.backup")
|
| 181 |
+
print("\nTo verify: python run_dpo.py --config config_dpo.yaml")
|
| 182 |
+
|
| 183 |
+
return True
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
import sys
|
| 187 |
+
|
| 188 |
+
filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"
|
| 189 |
+
|
| 190 |
+
print("DPO Trainer - Quick Fix Script")
|
| 191 |
+
print("="*80)
|
| 192 |
+
print("This script will apply the following critical fixes:")
|
| 193 |
+
print(" 1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)")
|
| 194 |
+
print(" 2. Add logging setup")
|
| 195 |
+
print(" 3. Add custom exceptions (DataFormattingError, DataValidationError)")
|
| 196 |
+
print(" 4. Add data validation function")
|
| 197 |
+
print(" 5. Add TRL version check")
|
| 198 |
+
print(" 6. Replace print with logger")
|
| 199 |
+
print("="*80)
|
| 200 |
+
print()
|
| 201 |
+
|
| 202 |
+
response = input("Apply fixes? [y/N]: ")
|
| 203 |
+
if response.lower() == 'y':
|
| 204 |
+
apply_fixes(filepath)
|
| 205 |
+
else:
|
| 206 |
+
print("Cancelled")
|
DPO-14b/config_dpo.yaml
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
run_dir: "./runs/dpo_run_14b_v1"
|
| 3 |
+
seed: 42
|
| 4 |
+
|
| 5 |
+
# WandB integration for experiment tracking
|
| 6 |
+
wandb:
|
| 7 |
+
enabled: true
|
| 8 |
+
project: "dpo-training"
|
| 9 |
+
entity: null
|
| 10 |
+
name: null
|
| 11 |
+
tags: ["dpo-lora", "preference-optimization"]
|
| 12 |
+
notes: null
|
| 13 |
+
|
| 14 |
+
model:
|
| 15 |
+
# Use the SFT model as base
|
| 16 |
+
repo_id: "../../Models/Qwen2.5-Coder-14B-CPT-SFT"
|
| 17 |
+
revision: null
|
| 18 |
+
|
| 19 |
+
# Used only when repo_id is a HF repo (not a local path)
|
| 20 |
+
base_local_dir: "base_model"
|
| 21 |
+
|
| 22 |
+
trust_remote_code: true
|
| 23 |
+
tokenizer_use_fast: true
|
| 24 |
+
device_map: "auto"
|
| 25 |
+
|
| 26 |
+
torch_dtype: "bfloat16" # "float16" | "bfloat16" | "float32"
|
| 27 |
+
|
| 28 |
+
# QLoRA
|
| 29 |
+
use_4bit: false
|
| 30 |
+
bnb_4bit_quant_type: "nf4"
|
| 31 |
+
bnb_4bit_use_double_quant: false
|
| 32 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 33 |
+
|
| 34 |
+
# optional: "flash_attention_2" | "sdpa" | null
|
| 35 |
+
attn_implementation: null
|
| 36 |
+
|
| 37 |
+
data:
|
| 38 |
+
train_jsonl: "dpo_pairs_generated.jsonl"
|
| 39 |
+
eval_jsonl: null
|
| 40 |
+
eval_split_ratio: 0.1
|
| 41 |
+
|
| 42 |
+
# Field names in your JSONL data for DPO
|
| 43 |
+
# DPO requires: prompt, chosen, rejected
|
| 44 |
+
prompt_field: "prompt"
|
| 45 |
+
chosen_field: "chosen"
|
| 46 |
+
rejected_field: "rejected"
|
| 47 |
+
|
| 48 |
+
# If you have a file-level F1 score field for ranking
|
| 49 |
+
score_field: "f1_score" # Optional: used for ranking if available
|
| 50 |
+
|
| 51 |
+
# Formatting options
|
| 52 |
+
format_type: "chatml" # "chatml" | "alpaca" | "custom"
|
| 53 |
+
|
| 54 |
+
# System prompt to prepend to all prompts
|
| 55 |
+
system_prompt: |
|
| 56 |
+
You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
|
| 57 |
+
|
| 58 |
+
## Output Format
|
| 59 |
+
|
| 60 |
+
##OUTPUT
|
| 61 |
+
Explain the data flow and why each component must change:
|
| 62 |
+
- Flow: [Input β Processing β Output with arrows]
|
| 63 |
+
- For each component: "The [ComponentName] ([path]) must [action] because [reason]βwithout this, [consequence]"
|
| 64 |
+
- Explain coupling between components
|
| 65 |
+
|
| 66 |
+
##SELECT
|
| 67 |
+
modify::crates/path/to/file.rs::impl::ComponentName
|
| 68 |
+
add::crates/another/file.rs::function::AnotherComponent
|
| 69 |
+
<EOS>
|
| 70 |
+
|
| 71 |
+
## Rules
|
| 72 |
+
|
| 73 |
+
1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
|
| 74 |
+
2. Use `::` for nested items: `status::StructName::Type::Name`
|
| 75 |
+
3. Always explain "must change because" and "without this"
|
| 76 |
+
3. Types of components: function, struct, enum, impl, trait
|
| 77 |
+
4. If there is extra information (e.g., enum variants), include that too.
|
| 78 |
+
5. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
|
| 79 |
+
|
| 80 |
+
max_length: 2048
|
| 81 |
+
shuffle: true
|
| 82 |
+
num_proc: 4
|
| 83 |
+
|
| 84 |
+
peft:
|
| 85 |
+
enabled: true
|
| 86 |
+
r: 16
|
| 87 |
+
lora_alpha: 32
|
| 88 |
+
lora_dropout: 0.05
|
| 89 |
+
bias: "none"
|
| 90 |
+
target_modules: "auto"
|
| 91 |
+
|
| 92 |
+
# DPO specific parameters
|
| 93 |
+
dpo:
|
| 94 |
+
beta: 0.1 # Temperature parameter for DPO loss (higher = less aggressive)
|
| 95 |
+
label_smoothing: 0.0 # Label smoothing for DPO
|
| 96 |
+
loss_type: "sigmoid" # "sigmoid" | "hinge" | "ipo" | "kto"
|
| 97 |
+
|
| 98 |
+
# Reference model settings
|
| 99 |
+
use_reference_model: true # If false, uses frozen copy of initial model
|
| 100 |
+
reference_free: false # If true, doesn't use reference model at all
|
| 101 |
+
|
| 102 |
+
train:
|
| 103 |
+
num_train_epochs: 3
|
| 104 |
+
|
| 105 |
+
per_device_train_batch_size: 1
|
| 106 |
+
per_device_eval_batch_size: 1
|
| 107 |
+
gradient_accumulation_steps: 8
|
| 108 |
+
|
| 109 |
+
learning_rate: 5e-5 # Lower than SFT for stability
|
| 110 |
+
weight_decay: 0.0
|
| 111 |
+
warmup_ratio: 0.1
|
| 112 |
+
lr_scheduler_type: "cosine"
|
| 113 |
+
|
| 114 |
+
optim: "adamw_torch"
|
| 115 |
+
max_grad_norm: 1.0
|
| 116 |
+
gradient_checkpointing: true
|
| 117 |
+
|
| 118 |
+
logging_steps: 2
|
| 119 |
+
save_strategy: "steps"
|
| 120 |
+
save_steps: 100
|
| 121 |
+
save_total_limit: 10
|
| 122 |
+
|
| 123 |
+
evaluation_strategy: "steps"
|
| 124 |
+
eval_steps: 25
|
| 125 |
+
load_best_model_at_end: true
|
| 126 |
+
|
| 127 |
+
# Early stopping
|
| 128 |
+
early_stopping:
|
| 129 |
+
enabled: true
|
| 130 |
+
patience: 5
|
| 131 |
+
min_delta: 0.001
|
| 132 |
+
metric: "eval_loss"
|
| 133 |
+
mode: "min"
|
| 134 |
+
|
| 135 |
+
resume_from_checkpoint: "auto"
|
| 136 |
+
|
| 137 |
+
merge:
|
| 138 |
+
enabled: true
|
| 139 |
+
merged_dtype: "float16"
|
| 140 |
+
max_shard_size: "2GB"
|
| 141 |
+
output_dir: "./merged_14b_dpo_lora"
|
DPO-14b/create_synthetic_pairs.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick script to convert SFT data to DPO format for training.
|
| 3 |
+
Since we don't have multiple model generations, we'll create synthetic pairs
|
| 4 |
+
by using the ground truth as "chosen" and creating degraded versions as "rejected".
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import random
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
sys.path.append(str(Path(__file__).parent))
|
| 13 |
+
from f1_score_utils import compute_file_level_f1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def degrade_output(output: str) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Create a degraded version of the output by:
|
| 19 |
+
1. Removing some file selections
|
| 20 |
+
2. Adding incorrect file selections
|
| 21 |
+
3. Keeping the explanation but modifying selections
|
| 22 |
+
"""
|
| 23 |
+
# Split into OUTPUT and SELECT sections
|
| 24 |
+
if "##SELECT" not in output:
|
| 25 |
+
return output
|
| 26 |
+
|
| 27 |
+
parts = output.split("##SELECT")
|
| 28 |
+
explanation = parts[0]
|
| 29 |
+
select_section = parts[1].split("<EOS>")[0] if "<EOS>" in parts[1] else parts[1]
|
| 30 |
+
|
| 31 |
+
# Extract file selections
|
| 32 |
+
lines = [l.strip() for l in select_section.strip().split('\n') if l.strip()]
|
| 33 |
+
|
| 34 |
+
if len(lines) <= 1:
|
| 35 |
+
return output # Can't degrade further
|
| 36 |
+
|
| 37 |
+
# Strategy: randomly remove 1-2 files OR add a random incorrect file
|
| 38 |
+
strategy = random.choice(['remove', 'add', 'replace'])
|
| 39 |
+
|
| 40 |
+
if strategy == 'remove' and len(lines) > 1:
|
| 41 |
+
# Remove 1-2 files
|
| 42 |
+
num_to_remove = min(random.randint(1, 2), len(lines) - 1)
|
| 43 |
+
new_lines = random.sample(lines, len(lines) - num_to_remove)
|
| 44 |
+
elif strategy == 'add':
|
| 45 |
+
# Add an incorrect file
|
| 46 |
+
fake_files = [
|
| 47 |
+
"crates/router/src/handlers/utils.rs::helper_function",
|
| 48 |
+
"crates/api_models/src/types.rs::RequestType",
|
| 49 |
+
"crates/common_utils/src/helpers.rs::parse_data",
|
| 50 |
+
"crates/diesel_models/src/schema.rs::table_definition",
|
| 51 |
+
]
|
| 52 |
+
new_lines = lines + [random.choice(fake_files)]
|
| 53 |
+
else: # replace
|
| 54 |
+
# Replace one file with incorrect one
|
| 55 |
+
if len(lines) > 0:
|
| 56 |
+
idx = random.randint(0, len(lines) - 1)
|
| 57 |
+
fake_files = [
|
| 58 |
+
"crates/router/src/handlers/utils.rs::helper_function",
|
| 59 |
+
"crates/api_models/src/types.rs::RequestType",
|
| 60 |
+
"crates/common_utils/src/helpers.rs::parse_data",
|
| 61 |
+
]
|
| 62 |
+
new_lines = lines.copy()
|
| 63 |
+
new_lines[idx] = random.choice(fake_files)
|
| 64 |
+
else:
|
| 65 |
+
new_lines = lines
|
| 66 |
+
|
| 67 |
+
# Reconstruct output
|
| 68 |
+
new_select = "\n".join(new_lines)
|
| 69 |
+
return f"{explanation}##SELECT\n{new_select}\n<EOS>"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def create_dpo_pairs(input_jsonl: str, output_jsonl: str, max_examples: int = None):
|
| 73 |
+
"""
|
| 74 |
+
Convert SFT data to DPO format by creating synthetic degraded versions.
|
| 75 |
+
"""
|
| 76 |
+
pairs_created = 0
|
| 77 |
+
examples_processed = 0
|
| 78 |
+
|
| 79 |
+
with open(input_jsonl, 'r') as f_in, open(output_jsonl, 'w') as f_out:
|
| 80 |
+
for line in f_in:
|
| 81 |
+
if max_examples and examples_processed >= max_examples:
|
| 82 |
+
break
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
data = json.loads(line)
|
| 86 |
+
except:
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
prompt = data.get("input", "")
|
| 90 |
+
ground_truth = data.get("output", "")
|
| 91 |
+
|
| 92 |
+
if not prompt or not ground_truth or "##SELECT" not in ground_truth:
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
# Create 2-3 degraded versions
|
| 96 |
+
num_degraded = random.randint(2, 3)
|
| 97 |
+
for _ in range(num_degraded):
|
| 98 |
+
degraded = degrade_output(ground_truth)
|
| 99 |
+
|
| 100 |
+
# Compute F1 scores
|
| 101 |
+
gt_metrics = compute_file_level_f1(ground_truth, ground_truth)
|
| 102 |
+
deg_metrics = compute_file_level_f1(degraded, ground_truth)
|
| 103 |
+
|
| 104 |
+
# Only create pair if there's a significant difference
|
| 105 |
+
if gt_metrics["f1"] - deg_metrics["f1"] >= 0.1:
|
| 106 |
+
pair = {
|
| 107 |
+
"prompt": prompt,
|
| 108 |
+
"chosen": ground_truth,
|
| 109 |
+
"rejected": degraded,
|
| 110 |
+
"chosen_f1": gt_metrics["f1"],
|
| 111 |
+
"rejected_f1": deg_metrics["f1"]
|
| 112 |
+
}
|
| 113 |
+
f_out.write(json.dumps(pair) + '\n')
|
| 114 |
+
pairs_created += 1
|
| 115 |
+
|
| 116 |
+
examples_processed += 1
|
| 117 |
+
if examples_processed % 100 == 0:
|
| 118 |
+
print(f"Processed {examples_processed} examples, created {pairs_created} pairs")
|
| 119 |
+
|
| 120 |
+
print(f"\nDone! Processed {examples_processed} examples")
|
| 121 |
+
print(f"Created {pairs_created} DPO pairs")
|
| 122 |
+
print(f"Average pairs per example: {pairs_created / max(examples_processed, 1):.2f}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
import argparse
|
| 127 |
+
|
| 128 |
+
parser = argparse.ArgumentParser()
|
| 129 |
+
parser.add_argument("--input", default="../../sft_output.jsonl")
|
| 130 |
+
parser.add_argument("--output", default="dpo_pairs_generated.jsonl")
|
| 131 |
+
parser.add_argument("--max-examples", type=int, default=None)
|
| 132 |
+
args = parser.parse_args()
|
| 133 |
+
|
| 134 |
+
print(f"Converting {args.input} to DPO format...")
|
| 135 |
+
create_dpo_pairs(args.input, args.output, args.max_examples)
|
| 136 |
+
print(f"Output saved to: {args.output}")
|
DPO-14b/dpo_dataset.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"prompt": "##TASK\nAdd subscription support to webhook event metadata and outgoing webhook content systems to enable proper subscription lifecycle event handling and webhook delivery for subscription-related operations.", "chosen": "##OUTPUT\nThe webhook system requires subscription event support to handle subscription lifecycle events (creation, payment, cancellation) and deliver them to merchants via outgoing webhooks. This affects the event metadata enumeration and webhook content transformation layers. The complete flow is: subscription event occurs β EventMetadata::Subscription variant created with subscription_id, invoice_id, payment_id β stored in database β retrieved for webhook delivery β OutgoingWebhookEventContent::foreign_from converts EventMetadata to OutgoingWebhookEventContent::Subscription β StripeWebhookObject::From transforms to StripeWebhookObject::Subscriptions for Stripe compatibility. The OutgoingWebhookEventContent::foreign_from function (router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from) must add the Subscription variant match arm because it's the bridge from storage EventMetadata to webhook content representationβwithout this, subscription events cannot be converted to webhook content and would be silently dropped during webhook processing. The StripeWebhookObject::From implementation (router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From) must handle OutgoingWebhookContent::SubscriptionDetails because it provides Stripe-compatible webhook formattingβwithout this, merchants using Stripe webhook endpoints would receive malformed or missing subscription events.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/router/src/compatibility/stripe/webhooks.rs::StripeWebhookObject::From\n<EOS>", "rejected": "##OUTPUT\nThe webhook system needs to support subscriptions by adding new event types. The subscription events need to be handled in the webhook processing pipeline.\n\n##SELECT\ncrates/router/src/core/webhooks/outgoing_v2.rs::OutgoingWebhookEventContent::foreign_from\ncrates/api_models/src/webhooks.rs::WebhookEvent\ncrates/router/src/handlers/webhooks.rs::handle_webhook\n<EOS>", "chosen_f1": 1.0, "rejected_f1": 0.5}
|
| 2 |
+
{"prompt": "##TASK\nUpdate subscription response mapping to use common_enums::SubscriptionStatus instead of subscription_types::SubscriptionStatus for consistent enum usage across the codebase.", "chosen": "##OUTPUT\nThe subscription system needs to standardize on a single SubscriptionStatus enum type to avoid type inconsistencies and import conflicts. The current flow is: subscription handler creates SubscriptionResponse β calls to_subscription_response method β converts string status to subscription_types::SubscriptionStatus β returns response. The to_subscription_response method (crates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response) must change because it's the conversion point where the status enum is instantiated. Using common_enums::SubscriptionStatus ensures consistency with other parts of the system that already use this common enum type. Without this change, the subscription module would maintain its own duplicate enum, creating potential mismatches and requiring conversions throughout the codebase when interacting with other modules that expect the common enum type.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\n<EOS>", "rejected": "##OUTPUT\nThe subscription response should use the common enum type for consistency. Need to update the import statements and change the enum usage.\n\n##SELECT\ncrates/subscriptions/src/core/subscription_handler.rs::SubscriptionWithHandler::to_subscription_response\ncrates/subscriptions/src/types.rs::SubscriptionStatus\n<EOS>", "chosen_f1": 1.0, "rejected_f1": 0.67}
|
| 3 |
+
{"prompt": "##TASK\nAdd subscription support to webhook event system by including Subscriptions in EventClass enum and implementing SubscriptionStatus to EventType mapping for proper webhook routing.", "chosen": "##OUTPUT\nThe webhook system needs to handle subscription-related events for proper event classification and routing. The current flow is: webhook event received β EventClass determines routing category β EventType specifies the exact event type β handlers process based on event type. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must include Subscriptions because it's the primary classification that determines which subsystem handles the event. The SubscriptionStatus transformer (crates/common_enums/src/transformers.rs::impl::Option<EventType>) must map subscription status changes to appropriate EventType values because it converts domain state changes into webhook events. Without EventClass::Subscriptions, subscription events would be misclassified or unhandled. Without the SubscriptionStatus mapping, subscription status changes wouldn't trigger appropriate webhook notifications, breaking integration with external systems that depend on subscription lifecycle events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>", "rejected": "##OUTPUT\nSubscription support requires adding the Subscriptions variant to EventClass and mapping status changes to events.\n\n##SELECT\ncrates/common_enums/src/enums.rs::EventClass\ncrates/common_enums/src/enums.rs::EventType\ncrates/common_enums/src/transformers.rs::impl::Option<EventType>\n<EOS>", "chosen_f1": 1.0, "rejected_f1": 0.75}
|
DPO-14b/dpo_pairs_generated.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf40fc92ca4b73f865da618cb023b565e98bd83abf64ba083ae113c3173b93b3
|
| 3 |
+
size 39769672
|
DPO-14b/f1_score_utils.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility for computing F1 scores at file level for ranking generated outputs.
|
| 3 |
+
This helps create preference pairs for DPO training.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
from typing import List, Set, Tuple, Dict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extract_files_from_selection(output_text: str) -> Set[str]:
|
| 13 |
+
"""
|
| 14 |
+
Extract file paths from ##SELECT section.
|
| 15 |
+
Expected format: modify::crates/path/to/file.rs::impl::ComponentName
|
| 16 |
+
Returns set of unique file paths.
|
| 17 |
+
"""
|
| 18 |
+
files = set()
|
| 19 |
+
|
| 20 |
+
# Find ##SELECT section
|
| 21 |
+
select_match = re.search(r'##SELECT\s*(.*?)<EOS>', output_text, re.DOTALL | re.IGNORECASE)
|
| 22 |
+
if not select_match:
|
| 23 |
+
return files
|
| 24 |
+
|
| 25 |
+
select_section = select_match.group(1)
|
| 26 |
+
|
| 27 |
+
# Extract file paths from each line
|
| 28 |
+
# Format: action::path::type::name
|
| 29 |
+
for line in select_section.strip().split('\n'):
|
| 30 |
+
line = line.strip()
|
| 31 |
+
if not line:
|
| 32 |
+
continue
|
| 33 |
+
|
| 34 |
+
# Split by :: and extract the file path (second component)
|
| 35 |
+
parts = line.split('::')
|
| 36 |
+
if len(parts) >= 2:
|
| 37 |
+
file_path = parts[1]
|
| 38 |
+
files.add(file_path)
|
| 39 |
+
|
| 40 |
+
return files
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def compute_file_level_f1(predicted: str, ground_truth: str) -> Dict[str, float]:
|
| 44 |
+
"""
|
| 45 |
+
Compute F1 score based on file-level predictions.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
predicted: Model output with ##SELECT section
|
| 49 |
+
ground_truth: Ground truth output with ##SELECT section
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Dictionary with precision, recall, f1 scores
|
| 53 |
+
"""
|
| 54 |
+
pred_files = extract_files_from_selection(predicted)
|
| 55 |
+
gt_files = extract_files_from_selection(ground_truth)
|
| 56 |
+
|
| 57 |
+
if len(gt_files) == 0:
|
| 58 |
+
# No ground truth files
|
| 59 |
+
if len(pred_files) == 0:
|
| 60 |
+
return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
|
| 61 |
+
else:
|
| 62 |
+
return {"precision": 0.0, "recall": 1.0, "f1": 0.0}
|
| 63 |
+
|
| 64 |
+
if len(pred_files) == 0:
|
| 65 |
+
# No predicted files but have ground truth
|
| 66 |
+
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
| 67 |
+
|
| 68 |
+
# Calculate metrics
|
| 69 |
+
true_positives = len(pred_files & gt_files)
|
| 70 |
+
false_positives = len(pred_files - gt_files)
|
| 71 |
+
false_negatives = len(gt_files - pred_files)
|
| 72 |
+
|
| 73 |
+
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
|
| 74 |
+
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
|
| 75 |
+
|
| 76 |
+
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
| 77 |
+
|
| 78 |
+
return {
|
| 79 |
+
"precision": precision,
|
| 80 |
+
"recall": recall,
|
| 81 |
+
"f1": f1,
|
| 82 |
+
"true_positives": true_positives,
|
| 83 |
+
"false_positives": false_positives,
|
| 84 |
+
"false_negatives": false_negatives,
|
| 85 |
+
"pred_files": list(pred_files),
|
| 86 |
+
"gt_files": list(gt_files),
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def rank_outputs_by_f1(outputs: List[str], ground_truth: str) -> List[Tuple[str, float, Dict]]:
|
| 91 |
+
"""
|
| 92 |
+
Rank multiple outputs by their F1 scores compared to ground truth.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
outputs: List of model outputs to rank
|
| 96 |
+
ground_truth: Ground truth output
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
List of tuples: (output, f1_score, metrics_dict) sorted by F1 descending
|
| 100 |
+
"""
|
| 101 |
+
ranked = []
|
| 102 |
+
for output in outputs:
|
| 103 |
+
metrics = compute_file_level_f1(output, ground_truth)
|
| 104 |
+
ranked.append((output, metrics["f1"], metrics))
|
| 105 |
+
|
| 106 |
+
# Sort by F1 score descending
|
| 107 |
+
ranked.sort(key=lambda x: x[1], reverse=True)
|
| 108 |
+
return ranked
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def create_dpo_pairs_from_generations(
|
| 112 |
+
prompt: str,
|
| 113 |
+
generations: List[str],
|
| 114 |
+
ground_truth: str,
|
| 115 |
+
min_f1_difference: float = 0.1
|
| 116 |
+
) -> List[Dict[str, str]]:
|
| 117 |
+
"""
|
| 118 |
+
Create DPO training pairs from multiple generations.
|
| 119 |
+
Uses F1 score to determine which generation is better.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
prompt: Input prompt/task
|
| 123 |
+
generations: List of generated outputs
|
| 124 |
+
ground_truth: Ground truth output
|
| 125 |
+
min_f1_difference: Minimum F1 difference to create a pair
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
List of DPO pairs: {"prompt": str, "chosen": str, "rejected": str}
|
| 129 |
+
"""
|
| 130 |
+
if len(generations) < 2:
|
| 131 |
+
return []
|
| 132 |
+
|
| 133 |
+
ranked = rank_outputs_by_f1(generations, ground_truth)
|
| 134 |
+
pairs = []
|
| 135 |
+
|
| 136 |
+
# Create pairs from ranked outputs
|
| 137 |
+
for i in range(len(ranked)):
|
| 138 |
+
for j in range(i + 1, len(ranked)):
|
| 139 |
+
better_output, better_f1, _ = ranked[i]
|
| 140 |
+
worse_output, worse_f1, _ = ranked[j]
|
| 141 |
+
|
| 142 |
+
# Only create pair if F1 difference is significant
|
| 143 |
+
if better_f1 - worse_f1 >= min_f1_difference:
|
| 144 |
+
pairs.append({
|
| 145 |
+
"prompt": prompt,
|
| 146 |
+
"chosen": better_output,
|
| 147 |
+
"rejected": worse_output,
|
| 148 |
+
"chosen_f1": better_f1,
|
| 149 |
+
"rejected_f1": worse_f1,
|
| 150 |
+
})
|
| 151 |
+
|
| 152 |
+
return pairs
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def convert_sft_to_dpo_with_sampling(
|
| 156 |
+
sft_jsonl_path: str,
|
| 157 |
+
output_jsonl_path: str,
|
| 158 |
+
model_inference_fn,
|
| 159 |
+
num_samples: int = 4,
|
| 160 |
+
min_f1_difference: float = 0.1,
|
| 161 |
+
temperature: float = 0.8
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Convert SFT dataset to DPO dataset by sampling multiple outputs and ranking by F1.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
sft_jsonl_path: Path to SFT JSONL file
|
| 168 |
+
output_jsonl_path: Path to output DPO JSONL file
|
| 169 |
+
model_inference_fn: Function that takes (prompt, num_samples, temperature) and returns List[str]
|
| 170 |
+
num_samples: Number of outputs to sample per prompt
|
| 171 |
+
min_f1_difference: Minimum F1 difference to create a pair
|
| 172 |
+
temperature: Sampling temperature
|
| 173 |
+
"""
|
| 174 |
+
pairs_created = 0
|
| 175 |
+
|
| 176 |
+
with open(sft_jsonl_path, 'r') as f_in, open(output_jsonl_path, 'w') as f_out:
|
| 177 |
+
for line in f_in:
|
| 178 |
+
data = json.loads(line)
|
| 179 |
+
|
| 180 |
+
# Extract prompt and ground truth
|
| 181 |
+
prompt = data.get("input", "")
|
| 182 |
+
ground_truth = data.get("output", "")
|
| 183 |
+
|
| 184 |
+
if not prompt or not ground_truth:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
# Generate multiple outputs
|
| 188 |
+
try:
|
| 189 |
+
generations = model_inference_fn(prompt, num_samples, temperature)
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Error generating outputs: {e}")
|
| 192 |
+
continue
|
| 193 |
+
|
| 194 |
+
# Create DPO pairs
|
| 195 |
+
pairs = create_dpo_pairs_from_generations(
|
| 196 |
+
prompt, generations, ground_truth, min_f1_difference
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Write pairs to output
|
| 200 |
+
for pair in pairs:
|
| 201 |
+
f_out.write(json.dumps(pair) + '\n')
|
| 202 |
+
pairs_created += 1
|
| 203 |
+
|
| 204 |
+
print(f"Created {pairs_created} DPO pairs from {sft_jsonl_path}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def prepare_dpo_data_from_instruct(
|
| 208 |
+
instruct_jsonl: str,
|
| 209 |
+
output_dpo_jsonl: str,
|
| 210 |
+
):
|
| 211 |
+
"""
|
| 212 |
+
Simple conversion from instruction data to DPO format.
|
| 213 |
+
This assumes you already have multiple outputs per input or will generate them.
|
| 214 |
+
|
| 215 |
+
For demonstration, this creates a basic structure. In practice, you need to:
|
| 216 |
+
1. Generate multiple outputs for each input
|
| 217 |
+
2. Rank them by F1 score
|
| 218 |
+
3. Create chosen/rejected pairs
|
| 219 |
+
"""
|
| 220 |
+
print(f"Converting {instruct_jsonl} to DPO format...")
|
| 221 |
+
print("Note: This requires generating multiple outputs per prompt.")
|
| 222 |
+
print("Use convert_sft_to_dpo_with_sampling() with your model for actual conversion.")
|
| 223 |
+
|
| 224 |
+
# Example structure - you'll need to fill this with actual generations
|
| 225 |
+
with open(instruct_jsonl, 'r') as f:
|
| 226 |
+
for line in f:
|
| 227 |
+
data = json.loads(line)
|
| 228 |
+
print(f"Input: {data.get('input', '')[:100]}...")
|
| 229 |
+
print(f"Ground truth output available: {len(data.get('output', ''))} chars")
|
| 230 |
+
print(" -> Need to generate multiple outputs and rank by F1 score")
|
| 231 |
+
print()
|
| 232 |
+
break # Just show one example
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
# Example usage
|
| 237 |
+
print("F1 Score Utility for File-Level Ranking")
|
| 238 |
+
print("=" * 50)
|
| 239 |
+
|
| 240 |
+
# Example 1: Compute F1 for two outputs
|
| 241 |
+
ground_truth = """
|
| 242 |
+
##OUTPUT
|
| 243 |
+
The webhook system requires subscription support.
|
| 244 |
+
##SELECT
|
| 245 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 246 |
+
crates/router/src/webhooks.rs::process_webhook
|
| 247 |
+
<EOS>
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
prediction1 = """
|
| 251 |
+
##OUTPUT
|
| 252 |
+
The webhook system requires subscription support.
|
| 253 |
+
##SELECT
|
| 254 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 255 |
+
crates/router/src/webhooks.rs::process_webhook
|
| 256 |
+
<EOS>
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
prediction2 = """
|
| 260 |
+
##OUTPUT
|
| 261 |
+
The webhook system requires subscription support.
|
| 262 |
+
##SELECT
|
| 263 |
+
crates/common_enums/src/enums.rs::EventClass
|
| 264 |
+
crates/router/src/handlers.rs::handle_request
|
| 265 |
+
<EOS>
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
print("\nExample 1: Perfect match")
|
| 269 |
+
metrics1 = compute_file_level_f1(prediction1, ground_truth)
|
| 270 |
+
print(f"F1 Score: {metrics1['f1']:.3f}")
|
| 271 |
+
print(f"Precision: {metrics1['precision']:.3f}, Recall: {metrics1['recall']:.3f}")
|
| 272 |
+
|
| 273 |
+
print("\nExample 2: Partial match")
|
| 274 |
+
metrics2 = compute_file_level_f1(prediction2, ground_truth)
|
| 275 |
+
print(f"F1 Score: {metrics2['f1']:.3f}")
|
| 276 |
+
print(f"Precision: {metrics2['precision']:.3f}, Recall: {metrics2['recall']:.3f}")
|
| 277 |
+
|
| 278 |
+
print("\nExample 3: Ranking outputs")
|
| 279 |
+
outputs = [prediction1, prediction2]
|
| 280 |
+
ranked = rank_outputs_by_f1(outputs, ground_truth)
|
| 281 |
+
print("Ranked outputs:")
|
| 282 |
+
for i, (output, f1, metrics) in enumerate(ranked, 1):
|
| 283 |
+
print(f" {i}. F1={f1:.3f} - {metrics['true_positives']} correct files")
|
DPO-14b/prepare_data.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data preparation utilities for converting SFT data to DPO/GRPO formats.
|
| 3 |
+
This script helps generate multiple outputs and create preference/ranking datasets.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import argparse
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Dict
|
| 10 |
+
from f1_score_utils import (
|
| 11 |
+
compute_file_level_f1,
|
| 12 |
+
rank_outputs_by_f1,
|
| 13 |
+
create_dpo_pairs_from_generations
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_model_for_generation(model_path: str):
|
| 18 |
+
"""
|
| 19 |
+
Load a model for generation. This is a placeholder - implement based on your setup.
|
| 20 |
+
"""
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
print(f"Loading model from {model_path}...")
|
| 25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 26 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 27 |
+
model_path,
|
| 28 |
+
torch_dtype=torch.bfloat16,
|
| 29 |
+
device_map="auto"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return model, tokenizer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def generate_multiple_outputs(
|
| 36 |
+
model,
|
| 37 |
+
tokenizer,
|
| 38 |
+
prompt: str,
|
| 39 |
+
num_samples: int = 4,
|
| 40 |
+
temperatures: List[float] = None,
|
| 41 |
+
max_new_tokens: int = 512
|
| 42 |
+
) -> List[str]:
|
| 43 |
+
"""
|
| 44 |
+
Generate multiple outputs for a single prompt using different temperatures.
|
| 45 |
+
"""
|
| 46 |
+
if temperatures is None:
|
| 47 |
+
temperatures = [0.6, 0.8, 1.0, 1.2][:num_samples]
|
| 48 |
+
|
| 49 |
+
outputs = []
|
| 50 |
+
for temp in temperatures:
|
| 51 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 52 |
+
|
| 53 |
+
generated = model.generate(
|
| 54 |
+
**inputs,
|
| 55 |
+
max_new_tokens=max_new_tokens,
|
| 56 |
+
temperature=temp,
|
| 57 |
+
do_sample=True,
|
| 58 |
+
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Extract only the new tokens (not the prompt)
|
| 62 |
+
output_text = tokenizer.decode(
|
| 63 |
+
generated[0][inputs.input_ids.shape[1]:],
|
| 64 |
+
skip_special_tokens=True
|
| 65 |
+
)
|
| 66 |
+
outputs.append(output_text)
|
| 67 |
+
|
| 68 |
+
return outputs
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def convert_sft_to_dpo(
|
| 72 |
+
sft_jsonl: str,
|
| 73 |
+
output_jsonl: str,
|
| 74 |
+
model_path: str = None,
|
| 75 |
+
num_samples: int = 4,
|
| 76 |
+
min_f1_difference: float = 0.1,
|
| 77 |
+
max_examples: int = None
|
| 78 |
+
):
|
| 79 |
+
"""
|
| 80 |
+
Convert SFT dataset to DPO format by generating multiple outputs and creating pairs.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
sft_jsonl: Path to SFT JSONL file
|
| 84 |
+
output_jsonl: Path to output DPO JSONL file
|
| 85 |
+
model_path: Path to model for generation (if None, you need pre-generated outputs)
|
| 86 |
+
num_samples: Number of outputs to generate per prompt
|
| 87 |
+
min_f1_difference: Minimum F1 difference to create a pair
|
| 88 |
+
max_examples: Maximum number of examples to process (None = all)
|
| 89 |
+
"""
|
| 90 |
+
if model_path:
|
| 91 |
+
model, tokenizer = load_model_for_generation(model_path)
|
| 92 |
+
else:
|
| 93 |
+
print("Warning: No model path provided. Expecting pre-generated outputs in data.")
|
| 94 |
+
model, tokenizer = None, None
|
| 95 |
+
|
| 96 |
+
pairs_created = 0
|
| 97 |
+
examples_processed = 0
|
| 98 |
+
|
| 99 |
+
with open(sft_jsonl, 'r') as f_in, open(output_jsonl, 'w') as f_out:
|
| 100 |
+
for line in f_in:
|
| 101 |
+
if max_examples and examples_processed >= max_examples:
|
| 102 |
+
break
|
| 103 |
+
|
| 104 |
+
data = json.loads(line)
|
| 105 |
+
prompt = data.get("input", "")
|
| 106 |
+
ground_truth = data.get("output", "")
|
| 107 |
+
|
| 108 |
+
if not prompt or not ground_truth:
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# Generate multiple outputs
|
| 112 |
+
if model and tokenizer:
|
| 113 |
+
try:
|
| 114 |
+
generations = generate_multiple_outputs(
|
| 115 |
+
model, tokenizer, prompt, num_samples
|
| 116 |
+
)
|
| 117 |
+
except Exception as e:
|
| 118 |
+
print(f"Error generating outputs: {e}")
|
| 119 |
+
continue
|
| 120 |
+
else:
|
| 121 |
+
# Expect outputs in the data
|
| 122 |
+
generations = data.get("outputs", [])
|
| 123 |
+
if len(generations) < 2:
|
| 124 |
+
print(f"Skipping example: need at least 2 outputs")
|
| 125 |
+
continue
|
| 126 |
+
|
| 127 |
+
# Create DPO pairs
|
| 128 |
+
pairs = create_dpo_pairs_from_generations(
|
| 129 |
+
prompt, generations, ground_truth, min_f1_difference
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Write pairs to output
|
| 133 |
+
for pair in pairs:
|
| 134 |
+
f_out.write(json.dumps(pair) + '\n')
|
| 135 |
+
pairs_created += 1
|
| 136 |
+
|
| 137 |
+
examples_processed += 1
|
| 138 |
+
if examples_processed % 10 == 0:
|
| 139 |
+
print(f"Processed {examples_processed} examples, created {pairs_created} pairs")
|
| 140 |
+
|
| 141 |
+
print(f"\nDone! Processed {examples_processed} examples, created {pairs_created} DPO pairs")
|
| 142 |
+
print(f"Output saved to: {output_jsonl}")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def convert_sft_to_grpo(
|
| 146 |
+
sft_jsonl: str,
|
| 147 |
+
output_jsonl: str,
|
| 148 |
+
model_path: str = None,
|
| 149 |
+
num_samples: int = 4,
|
| 150 |
+
max_examples: int = None
|
| 151 |
+
):
|
| 152 |
+
"""
|
| 153 |
+
Convert SFT dataset to GRPO format by generating multiple outputs and computing scores.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
sft_jsonl: Path to SFT JSONL file
|
| 157 |
+
output_jsonl: Path to output GRPO JSONL file
|
| 158 |
+
model_path: Path to model for generation
|
| 159 |
+
num_samples: Number of outputs to generate per prompt
|
| 160 |
+
max_examples: Maximum number of examples to process (None = all)
|
| 161 |
+
"""
|
| 162 |
+
if model_path:
|
| 163 |
+
model, tokenizer = load_model_for_generation(model_path)
|
| 164 |
+
else:
|
| 165 |
+
print("Warning: No model path provided. Expecting pre-generated outputs in data.")
|
| 166 |
+
model, tokenizer = None, None
|
| 167 |
+
|
| 168 |
+
examples_created = 0
|
| 169 |
+
examples_processed = 0
|
| 170 |
+
|
| 171 |
+
with open(sft_jsonl, 'r') as f_in, open(output_jsonl, 'w') as f_out:
|
| 172 |
+
for line in f_in:
|
| 173 |
+
if max_examples and examples_processed >= max_examples:
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
data = json.loads(line)
|
| 177 |
+
prompt = data.get("input", "")
|
| 178 |
+
ground_truth = data.get("output", "")
|
| 179 |
+
|
| 180 |
+
if not prompt or not ground_truth:
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
# Generate multiple outputs
|
| 184 |
+
if model and tokenizer:
|
| 185 |
+
try:
|
| 186 |
+
generations = generate_multiple_outputs(
|
| 187 |
+
model, tokenizer, prompt, num_samples
|
| 188 |
+
)
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f"Error generating outputs: {e}")
|
| 191 |
+
continue
|
| 192 |
+
else:
|
| 193 |
+
# Expect outputs in the data
|
| 194 |
+
generations = data.get("outputs", [])
|
| 195 |
+
if len(generations) < 2:
|
| 196 |
+
print(f"Skipping example: need at least 2 outputs")
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
# Compute F1 scores for all generations
|
| 200 |
+
scores = []
|
| 201 |
+
for generation in generations:
|
| 202 |
+
metrics = compute_file_level_f1(generation, ground_truth)
|
| 203 |
+
scores.append(metrics["f1"])
|
| 204 |
+
|
| 205 |
+
# Create GRPO example
|
| 206 |
+
grpo_example = {
|
| 207 |
+
"prompt": prompt,
|
| 208 |
+
"completions": generations,
|
| 209 |
+
"scores": scores
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
f_out.write(json.dumps(grpo_example) + '\n')
|
| 213 |
+
examples_created += 1
|
| 214 |
+
examples_processed += 1
|
| 215 |
+
|
| 216 |
+
if examples_processed % 10 == 0:
|
| 217 |
+
print(f"Processed {examples_processed} examples")
|
| 218 |
+
print(f" Last example F1 scores: {[f'{s:.3f}' for s in scores]}")
|
| 219 |
+
|
| 220 |
+
print(f"\nDone! Created {examples_created} GRPO examples from {examples_processed} SFT examples")
|
| 221 |
+
print(f"Output saved to: {output_jsonl}")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def analyze_dataset(jsonl_path: str, dataset_type: str = "auto"):
|
| 225 |
+
"""
|
| 226 |
+
Analyze a dataset and print statistics.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
jsonl_path: Path to JSONL file
|
| 230 |
+
dataset_type: "sft", "dpo", "grpo", or "auto" (auto-detect)
|
| 231 |
+
"""
|
| 232 |
+
with open(jsonl_path, 'r') as f:
|
| 233 |
+
lines = f.readlines()
|
| 234 |
+
|
| 235 |
+
if not lines:
|
| 236 |
+
print("Empty dataset")
|
| 237 |
+
return
|
| 238 |
+
|
| 239 |
+
first = json.loads(lines[0])
|
| 240 |
+
|
| 241 |
+
# Auto-detect type
|
| 242 |
+
if dataset_type == "auto":
|
| 243 |
+
if "chosen" in first and "rejected" in first:
|
| 244 |
+
dataset_type = "dpo"
|
| 245 |
+
elif "completions" in first and "scores" in first:
|
| 246 |
+
dataset_type = "grpo"
|
| 247 |
+
else:
|
| 248 |
+
dataset_type = "sft"
|
| 249 |
+
|
| 250 |
+
print(f"\nDataset Analysis: {jsonl_path}")
|
| 251 |
+
print(f"Type: {dataset_type.upper()}")
|
| 252 |
+
print(f"Total examples: {len(lines)}")
|
| 253 |
+
|
| 254 |
+
if dataset_type == "dpo":
|
| 255 |
+
f1_diffs = []
|
| 256 |
+
for line in lines:
|
| 257 |
+
data = json.loads(line)
|
| 258 |
+
chosen_f1 = data.get("chosen_f1", 1.0)
|
| 259 |
+
rejected_f1 = data.get("rejected_f1", 0.0)
|
| 260 |
+
f1_diffs.append(chosen_f1 - rejected_f1)
|
| 261 |
+
|
| 262 |
+
print(f"Average F1 difference: {sum(f1_diffs) / len(f1_diffs):.3f}")
|
| 263 |
+
print(f"Min F1 difference: {min(f1_diffs):.3f}")
|
| 264 |
+
print(f"Max F1 difference: {max(f1_diffs):.3f}")
|
| 265 |
+
|
| 266 |
+
elif dataset_type == "grpo":
|
| 267 |
+
all_scores = []
|
| 268 |
+
completion_counts = []
|
| 269 |
+
for line in lines:
|
| 270 |
+
data = json.loads(line)
|
| 271 |
+
scores = data.get("scores", [])
|
| 272 |
+
all_scores.extend(scores)
|
| 273 |
+
completion_counts.append(len(scores))
|
| 274 |
+
|
| 275 |
+
print(f"Average completions per prompt: {sum(completion_counts) / len(completion_counts):.1f}")
|
| 276 |
+
print(f"Min completions: {min(completion_counts)}")
|
| 277 |
+
print(f"Max completions: {max(completion_counts)}")
|
| 278 |
+
print(f"Average F1 score: {sum(all_scores) / len(all_scores):.3f}")
|
| 279 |
+
print(f"F1 score range: [{min(all_scores):.3f}, {max(all_scores):.3f}]")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def main():
|
| 283 |
+
parser = argparse.ArgumentParser(description="Convert SFT data to DPO/GRPO formats")
|
| 284 |
+
parser.add_argument("--input", required=True, help="Input SFT JSONL file")
|
| 285 |
+
parser.add_argument("--output", required=True, help="Output JSONL file")
|
| 286 |
+
parser.add_argument("--format", choices=["dpo", "grpo"], required=True,
|
| 287 |
+
help="Output format")
|
| 288 |
+
parser.add_argument("--model", default=None,
|
| 289 |
+
help="Path to model for generation (optional)")
|
| 290 |
+
parser.add_argument("--num-samples", type=int, default=4,
|
| 291 |
+
help="Number of outputs to generate per prompt")
|
| 292 |
+
parser.add_argument("--max-examples", type=int, default=None,
|
| 293 |
+
help="Maximum number of examples to process")
|
| 294 |
+
parser.add_argument("--min-f1-diff", type=float, default=0.1,
|
| 295 |
+
help="Minimum F1 difference for DPO pairs")
|
| 296 |
+
parser.add_argument("--analyze", action="store_true",
|
| 297 |
+
help="Analyze the output dataset after creation")
|
| 298 |
+
|
| 299 |
+
args = parser.parse_args()
|
| 300 |
+
|
| 301 |
+
print(f"Converting {args.input} to {args.format.upper()} format...")
|
| 302 |
+
print(f"Output: {args.output}")
|
| 303 |
+
|
| 304 |
+
if args.format == "dpo":
|
| 305 |
+
convert_sft_to_dpo(
|
| 306 |
+
args.input,
|
| 307 |
+
args.output,
|
| 308 |
+
args.model,
|
| 309 |
+
args.num_samples,
|
| 310 |
+
args.min_f1_diff,
|
| 311 |
+
args.max_examples
|
| 312 |
+
)
|
| 313 |
+
elif args.format == "grpo":
|
| 314 |
+
convert_sft_to_grpo(
|
| 315 |
+
args.input,
|
| 316 |
+
args.output,
|
| 317 |
+
args.model,
|
| 318 |
+
args.num_samples,
|
| 319 |
+
args.max_examples
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
if args.analyze:
|
| 323 |
+
analyze_dataset(args.output, args.format)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
if __name__ == "__main__":
|
| 327 |
+
# Example usage without CLI
|
| 328 |
+
import sys
|
| 329 |
+
|
| 330 |
+
if len(sys.argv) == 1:
|
| 331 |
+
print("Data Preparation Utilities")
|
| 332 |
+
print("=" * 50)
|
| 333 |
+
print("\nUsage:")
|
| 334 |
+
print(" python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo")
|
| 335 |
+
print(" python prepare_data.py --input instruct_data.jsonl --output grpo_data.jsonl --format grpo")
|
| 336 |
+
print("\nWith model generation:")
|
| 337 |
+
print(" python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo \\")
|
| 338 |
+
print(" --model ./runs/instruct_run_14b_v1/merged_14b_instruct_lora --num-samples 4")
|
| 339 |
+
print("\nAnalyze dataset:")
|
| 340 |
+
print(" python prepare_data.py --input dpo_data.jsonl --output /dev/null --format dpo --analyze")
|
| 341 |
+
sys.exit(0)
|
| 342 |
+
|
| 343 |
+
main()
|
DPO-14b/requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.1.0
|
| 3 |
+
transformers>=4.41.0
|
| 4 |
+
datasets>=2.18.0
|
| 5 |
+
accelerate>=0.30.0
|
| 6 |
+
|
| 7 |
+
# PEFT / QLoRA
|
| 8 |
+
peft>=0.11.1
|
| 9 |
+
bitsandbytes>=0.43.1
|
| 10 |
+
|
| 11 |
+
# TRL for DPO
|
| 12 |
+
trl>=0.8.0
|
| 13 |
+
|
| 14 |
+
# Hugging Face Hub
|
| 15 |
+
huggingface_hub>=0.23.0
|
| 16 |
+
|
| 17 |
+
# Config + utilities
|
| 18 |
+
pyyaml>=6.0
|
| 19 |
+
tqdm>=4.66.0
|
| 20 |
+
|
| 21 |
+
# Tokenizers and safetensors
|
| 22 |
+
tokenizers>=0.15.0
|
| 23 |
+
safetensors>=0.4.2
|
| 24 |
+
|
| 25 |
+
# Experiment tracking
|
| 26 |
+
wandb>=0.16.0
|
| 27 |
+
|
| 28 |
+
# For F1 score computation
|
| 29 |
+
scikit-learn>=1.3.0
|
DPO-14b/run_dpo.py
ADDED
|
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect
|
| 4 |
+
import math
|
| 5 |
+
import gc
|
| 6 |
+
import time
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import yaml
|
| 13 |
+
from datasets import load_dataset, DatasetDict
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoTokenizer,
|
| 17 |
+
AutoModelForCausalLM,
|
| 18 |
+
BitsAndBytesConfig,
|
| 19 |
+
TrainingArguments,
|
| 20 |
+
TrainerCallback,
|
| 21 |
+
EarlyStoppingCallback,
|
| 22 |
+
set_seed,
|
| 23 |
+
)
|
| 24 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 25 |
+
from peft import (
|
| 26 |
+
LoraConfig,
|
| 27 |
+
get_peft_model,
|
| 28 |
+
prepare_model_for_kbit_training,
|
| 29 |
+
PeftModel,
|
| 30 |
+
)
|
| 31 |
+
from trl import DPOTrainer, DPOConfig
|
| 32 |
+
|
| 33 |
+
# Version check for TRL
|
| 34 |
+
try:
|
| 35 |
+
from packaging import version
|
| 36 |
+
import trl
|
| 37 |
+
if version.parse(trl.__version__) < version.parse("0.7.0"):
|
| 38 |
+
logger.warning(f"TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
|
| 39 |
+
except ImportError:
|
| 40 |
+
logger.warning("Could not verify TRL version")
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
import wandb
|
| 44 |
+
WANDB_AVAILABLE = True
|
| 45 |
+
except ImportError:
|
| 46 |
+
WANDB_AVAILABLE = False
|
| 47 |
+
wandb = None
|
| 48 |
+
|
| 49 |
+
# Setup logging
|
| 50 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 51 |
+
logger = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# --------------------------
|
| 55 |
+
# Custom Exceptions
|
| 56 |
+
# --------------------------
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class DataFormattingError(Exception):
|
| 60 |
+
"""Exception raised for errors in data formatting."""
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DataValidationError(Exception):
|
| 65 |
+
"""Exception raised for errors in data validation."""
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --------------------------
|
| 70 |
+
# Helpers
|
| 71 |
+
# --------------------------
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 75 |
+
s = (s or "").lower()
|
| 76 |
+
if s in ("float16", "fp16"):
|
| 77 |
+
return torch.float16
|
| 78 |
+
if s in ("bfloat16", "bf16"):
|
| 79 |
+
return torch.bfloat16
|
| 80 |
+
if s in ("float32", "fp32"):
|
| 81 |
+
return torch.float32
|
| 82 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _now_iso() -> str:
|
| 86 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _safe_exp(x: float) -> float:
|
| 90 |
+
x = min(float(x), 50.0)
|
| 91 |
+
return float(math.exp(x))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _ensure_dir(p: Path) -> Path:
|
| 95 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 96 |
+
return p
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 100 |
+
if not p.exists() or not p.is_dir():
|
| 101 |
+
return False
|
| 102 |
+
if (p / "config.json").exists():
|
| 103 |
+
return True
|
| 104 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 105 |
+
return True
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _infer_target_modules(model) -> List[str]:
|
| 110 |
+
names = set()
|
| 111 |
+
for n, _ in model.named_modules():
|
| 112 |
+
names.add(n.split(".")[-1])
|
| 113 |
+
|
| 114 |
+
for group in [
|
| 115 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 116 |
+
["Wqkv", "out_proj"],
|
| 117 |
+
["query_key_value", "dense"],
|
| 118 |
+
["c_attn", "c_proj"],
|
| 119 |
+
]:
|
| 120 |
+
if all(x in names for x in group):
|
| 121 |
+
return group
|
| 122 |
+
|
| 123 |
+
fallback = [
|
| 124 |
+
x
|
| 125 |
+
for x in [
|
| 126 |
+
"q_proj",
|
| 127 |
+
"k_proj",
|
| 128 |
+
"v_proj",
|
| 129 |
+
"o_proj",
|
| 130 |
+
"c_attn",
|
| 131 |
+
"c_proj",
|
| 132 |
+
"out_proj",
|
| 133 |
+
"dense",
|
| 134 |
+
]
|
| 135 |
+
if x in names
|
| 136 |
+
]
|
| 137 |
+
if fallback:
|
| 138 |
+
return fallback
|
| 139 |
+
|
| 140 |
+
raise ValueError(
|
| 141 |
+
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 146 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# --------------------------
|
| 150 |
+
# Wandb Integration
|
| 151 |
+
# --------------------------
|
| 152 |
+
|
| 153 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 154 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 155 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 156 |
+
|
| 157 |
+
if not wandb_cfg.get("enabled", False):
|
| 158 |
+
print("Wandb logging disabled")
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
if not WANDB_AVAILABLE:
|
| 162 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
project = wandb_cfg.get("project", "dpo-training")
|
| 166 |
+
entity = wandb_cfg.get("entity", None)
|
| 167 |
+
name = wandb_cfg.get("name", None)
|
| 168 |
+
tags = wandb_cfg.get("tags", [])
|
| 169 |
+
notes = wandb_cfg.get("notes", None)
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
wandb.init(
|
| 173 |
+
project=project,
|
| 174 |
+
entity=entity,
|
| 175 |
+
name=name,
|
| 176 |
+
tags=tags,
|
| 177 |
+
notes=notes,
|
| 178 |
+
dir=str(run_dir),
|
| 179 |
+
config={
|
| 180 |
+
"model": cfg.get("model", {}),
|
| 181 |
+
"data": cfg.get("data", {}),
|
| 182 |
+
"peft": cfg.get("peft", {}),
|
| 183 |
+
"dpo": cfg.get("dpo", {}),
|
| 184 |
+
"train": cfg.get("train", {}),
|
| 185 |
+
"run_dir": str(run_dir),
|
| 186 |
+
}
|
| 187 |
+
)
|
| 188 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 189 |
+
return wandb
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def finish_wandb():
|
| 196 |
+
"""Finish Wandb run if active."""
|
| 197 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 198 |
+
wandb.finish()
|
| 199 |
+
print("Wandb run finished")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# --------------------------
|
| 203 |
+
# JSONL Logger Callback
|
| 204 |
+
# --------------------------
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 208 |
+
def __init__(self, run_dir: Path):
|
| 209 |
+
self.run_dir = run_dir
|
| 210 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 211 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 212 |
+
self.start_time = None
|
| 213 |
+
|
| 214 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 215 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 216 |
+
return None
|
| 217 |
+
elapsed = time.time() - self.start_time
|
| 218 |
+
sec_per_step = elapsed / global_step
|
| 219 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 220 |
+
h = int(remaining // 3600)
|
| 221 |
+
m = int((remaining % 3600) // 60)
|
| 222 |
+
s = int(remaining % 60)
|
| 223 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 224 |
+
|
| 225 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 226 |
+
self.start_time = time.time()
|
| 227 |
+
|
| 228 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 229 |
+
if not logs:
|
| 230 |
+
return
|
| 231 |
+
|
| 232 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 233 |
+
progress_pct = (
|
| 234 |
+
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 235 |
+
)
|
| 236 |
+
epoch_pct = None
|
| 237 |
+
if (
|
| 238 |
+
state.epoch is not None
|
| 239 |
+
and args.num_train_epochs
|
| 240 |
+
and args.num_train_epochs > 0
|
| 241 |
+
):
|
| 242 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 243 |
+
|
| 244 |
+
payload = {
|
| 245 |
+
"ts": _now_iso(),
|
| 246 |
+
"event": "train_log",
|
| 247 |
+
"step": int(state.global_step),
|
| 248 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 249 |
+
"progress_pct": (
|
| 250 |
+
round(progress_pct, 2) if progress_pct is not None else None
|
| 251 |
+
),
|
| 252 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 253 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 254 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 255 |
+
**logs,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 259 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 260 |
+
|
| 261 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 262 |
+
if not metrics:
|
| 263 |
+
return
|
| 264 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 265 |
+
|
| 266 |
+
payload = {
|
| 267 |
+
"ts": _now_iso(),
|
| 268 |
+
"event": "eval",
|
| 269 |
+
"step": int(state.global_step),
|
| 270 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 271 |
+
**metrics,
|
| 272 |
+
}
|
| 273 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 274 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# --------------------------
|
| 278 |
+
# Custom Exceptions
|
| 279 |
+
# --------------------------
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class DataFormattingError(Exception):
|
| 283 |
+
"""Exception raised for errors in data formatting."""
|
| 284 |
+
pass
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class DataValidationError(Exception):
|
| 288 |
+
"""Exception raised for errors in data validation."""
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# --------------------------
|
| 293 |
+
# Data Pipeline (DPO Format)
|
| 294 |
+
# --------------------------
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def format_dpo_example(
|
| 298 |
+
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
|
| 299 |
+
) -> Dict[str, Any]:
|
| 300 |
+
"""
|
| 301 |
+
Format DPO data which requires prompt, chosen, and rejected completions.
|
| 302 |
+
Returns formatted prompt, chosen, and rejected texts.
|
| 303 |
+
Raises DataFormattingError if formatting fails.
|
| 304 |
+
"""
|
| 305 |
+
data_cfg = cfg["data"]
|
| 306 |
+
format_type = data_cfg.get("format_type", "chatml")
|
| 307 |
+
|
| 308 |
+
# Get field names from config
|
| 309 |
+
prompt_field = data_cfg.get("prompt_field", "prompt")
|
| 310 |
+
chosen_field = data_cfg.get("chosen_field", "chosen")
|
| 311 |
+
rejected_field = data_cfg.get("rejected_field", "rejected")
|
| 312 |
+
|
| 313 |
+
# Extract text from example
|
| 314 |
+
prompt = example.get(prompt_field, "")
|
| 315 |
+
chosen = example.get(chosen_field, "")
|
| 316 |
+
rejected = example.get(rejected_field, "")
|
| 317 |
+
|
| 318 |
+
# Validate required fields
|
| 319 |
+
if not prompt:
|
| 320 |
+
raise DataFormattingError(f"Empty prompt field: {prompt_field}")
|
| 321 |
+
if not chosen:
|
| 322 |
+
raise DataFormattingError(f"Empty chosen field: {chosen_field}")
|
| 323 |
+
if not rejected:
|
| 324 |
+
raise DataFormattingError(f"Empty rejected field: {rejected_field}")
|
| 325 |
+
|
| 326 |
+
if format_type == "chatml":
|
| 327 |
+
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
|
| 328 |
+
|
| 329 |
+
# Format prompt with system message
|
| 330 |
+
messages = []
|
| 331 |
+
if system_prompt:
|
| 332 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 333 |
+
messages.append({"role": "user", "content": prompt})
|
| 334 |
+
|
| 335 |
+
# Apply chat template for prompt only (without assistant response)
|
| 336 |
+
formatted_prompt = tokenizer.apply_chat_template(
|
| 337 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Chosen and rejected are just the completions (will be added by DPOTrainer)
|
| 341 |
+
formatted_chosen = chosen
|
| 342 |
+
formatted_rejected = rejected
|
| 343 |
+
|
| 344 |
+
# Add EOS token to completions
|
| 345 |
+
if tokenizer.eos_token:
|
| 346 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 347 |
+
formatted_chosen += tokenizer.eos_token
|
| 348 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 349 |
+
formatted_rejected += tokenizer.eos_token
|
| 350 |
+
|
| 351 |
+
elif format_type == "alpaca":
|
| 352 |
+
# Alpaca format
|
| 353 |
+
prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
|
| 354 |
+
formatted_prompt = prefix
|
| 355 |
+
formatted_chosen = chosen
|
| 356 |
+
formatted_rejected = rejected
|
| 357 |
+
|
| 358 |
+
if tokenizer.eos_token:
|
| 359 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 360 |
+
formatted_chosen += tokenizer.eos_token
|
| 361 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 362 |
+
formatted_rejected += tokenizer.eos_token
|
| 363 |
+
|
| 364 |
+
elif format_type == "custom":
|
| 365 |
+
# Custom template
|
| 366 |
+
template = data_cfg.get("custom_template", "{prompt}")
|
| 367 |
+
formatted_prompt = template.format(prompt=prompt)
|
| 368 |
+
formatted_chosen = chosen
|
| 369 |
+
formatted_rejected = rejected
|
| 370 |
+
|
| 371 |
+
if tokenizer.eos_token:
|
| 372 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 373 |
+
formatted_chosen += tokenizer.eos_token
|
| 374 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 375 |
+
formatted_rejected += tokenizer.eos_token
|
| 376 |
+
else:
|
| 377 |
+
raise ValueError(f"Unsupported format_type: {format_type}")
|
| 378 |
+
|
| 379 |
+
return {
|
| 380 |
+
"prompt": formatted_prompt,
|
| 381 |
+
"chosen": formatted_chosen,
|
| 382 |
+
"rejected": formatted_rejected,
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def validate_dpo_data(dataset, stage: str = "train") -> None:
|
| 387 |
+
"""
|
| 388 |
+
Validate DPO dataset has all required fields and proper structure.
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
dataset: Dataset to validate
|
| 392 |
+
stage: Training stage ("train" or "eval")
|
| 393 |
+
|
| 394 |
+
Raises:
|
| 395 |
+
DataValidationError if validation fails
|
| 396 |
+
"""
|
| 397 |
+
required_fields = ["prompt", "chosen", "rejected"]
|
| 398 |
+
|
| 399 |
+
# Check required fields exist
|
| 400 |
+
for field in required_fields:
|
| 401 |
+
if field not in dataset.column_names:
|
| 402 |
+
raise DataValidationError(
|
| 403 |
+
f"{stage} dataset missing required field: {field}. "
|
| 404 |
+
f"Available fields: {dataset.column_names}"
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
# Sample validation - check first example
|
| 408 |
+
if len(dataset) > 0:
|
| 409 |
+
sample = dataset[0]
|
| 410 |
+
for field in required_fields:
|
| 411 |
+
if not sample[field] or len(sample[field].strip()) == 0:
|
| 412 |
+
logger.warning(f"{stage} dataset has empty {field} in first example")
|
| 413 |
+
|
| 414 |
+
logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 418 |
+
"""
|
| 419 |
+
Build datasets for DPO training.
|
| 420 |
+
Expected JSONL format: {"prompt": "...", "chosen": "...", "rejected": "..."}
|
| 421 |
+
Or with custom field names specified in config.
|
| 422 |
+
"""
|
| 423 |
+
data_cfg = cfg["data"]
|
| 424 |
+
train_path = data_cfg["train_jsonl"]
|
| 425 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 426 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 427 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 428 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 429 |
+
|
| 430 |
+
# Ensure tokenizer has pad token
|
| 431 |
+
if tokenizer.pad_token is None:
|
| 432 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 433 |
+
|
| 434 |
+
# Load datasets
|
| 435 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 436 |
+
|
| 437 |
+
if eval_path:
|
| 438 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 439 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 440 |
+
else:
|
| 441 |
+
if 0.0 < split_ratio < 1.0:
|
| 442 |
+
split = ds["train"].train_test_split(
|
| 443 |
+
test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
|
| 444 |
+
)
|
| 445 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 446 |
+
else:
|
| 447 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 448 |
+
|
| 449 |
+
# Format DPO examples with error handling
|
| 450 |
+
def format_fn(examples):
|
| 451 |
+
prompts = []
|
| 452 |
+
chosen_list = []
|
| 453 |
+
rejected_list = []
|
| 454 |
+
errors = 0
|
| 455 |
+
|
| 456 |
+
for i in range(len(examples[list(examples.keys())[0]])):
|
| 457 |
+
example = {k: examples[k][i] for k in examples.keys()}
|
| 458 |
+
try:
|
| 459 |
+
formatted = format_dpo_example(example, cfg, tokenizer)
|
| 460 |
+
prompts.append(formatted["prompt"])
|
| 461 |
+
chosen_list.append(formatted["chosen"])
|
| 462 |
+
rejected_list.append(formatted["rejected"])
|
| 463 |
+
except (DataFormattingError, Exception) as e:
|
| 464 |
+
errors += 1
|
| 465 |
+
if errors <= 5: # Log first 5 errors
|
| 466 |
+
logger.warning(f"Failed to format example {i}: {e}")
|
| 467 |
+
# Add empty placeholder to maintain batch structure
|
| 468 |
+
prompts.append("")
|
| 469 |
+
chosen_list.append("")
|
| 470 |
+
rejected_list.append("")
|
| 471 |
+
|
| 472 |
+
if errors > 0:
|
| 473 |
+
logger.warning(f"Total formatting errors in batch: {errors}")
|
| 474 |
+
|
| 475 |
+
return {
|
| 476 |
+
"prompt": prompts,
|
| 477 |
+
"chosen": chosen_list,
|
| 478 |
+
"rejected": rejected_list,
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
logger.info("Formatting train DPO data...")
|
| 482 |
+
formatted_train = dsd["train"].map(
|
| 483 |
+
format_fn,
|
| 484 |
+
batched=True,
|
| 485 |
+
num_proc=num_proc,
|
| 486 |
+
remove_columns=dsd["train"].column_names,
|
| 487 |
+
desc="Formatting train DPO data",
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Filter out failed examples (empty prompts)
|
| 491 |
+
formatted_train = formatted_train.filter(lambda x: len(x["prompt"]) > 0)
|
| 492 |
+
logger.info(f"Train dataset after filtering: {len(formatted_train)} examples")
|
| 493 |
+
|
| 494 |
+
# Validate formatted data
|
| 495 |
+
validate_dpo_data(formatted_train, "train")
|
| 496 |
+
|
| 497 |
+
formatted_eval = None
|
| 498 |
+
if dsd["eval"] is not None:
|
| 499 |
+
logger.info("Formatting eval DPO data...")
|
| 500 |
+
formatted_eval = dsd["eval"].map(
|
| 501 |
+
format_fn,
|
| 502 |
+
batched=True,
|
| 503 |
+
num_proc=num_proc,
|
| 504 |
+
remove_columns=dsd["eval"].column_names,
|
| 505 |
+
desc="Formatting eval DPO data",
|
| 506 |
+
)
|
| 507 |
+
formatted_eval = formatted_eval.filter(lambda x: len(x["prompt"]) > 0)
|
| 508 |
+
logger.info(f"Eval dataset after filtering: {len(formatted_eval)} examples")
|
| 509 |
+
validate_dpo_data(formatted_eval, "eval")
|
| 510 |
+
|
| 511 |
+
if shuffle:
|
| 512 |
+
formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 513 |
+
|
| 514 |
+
return formatted_train, formatted_eval
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# --------------------------
|
| 518 |
+
# Model Loading + PEFT
|
| 519 |
+
# --------------------------
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 523 |
+
model_cfg = cfg["model"]
|
| 524 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 525 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 526 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 527 |
+
|
| 528 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 529 |
+
str(base_dir),
|
| 530 |
+
use_fast=use_fast,
|
| 531 |
+
trust_remote_code=trust_remote_code,
|
| 532 |
+
)
|
| 533 |
+
if tokenizer.pad_token is None:
|
| 534 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 535 |
+
|
| 536 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 537 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 538 |
+
|
| 539 |
+
quant_cfg = None
|
| 540 |
+
if use_4bit:
|
| 541 |
+
quant_cfg = BitsAndBytesConfig(
|
| 542 |
+
load_in_4bit=True,
|
| 543 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 544 |
+
bnb_4bit_use_double_quant=bool(
|
| 545 |
+
model_cfg.get("bnb_4bit_use_double_quant", True)
|
| 546 |
+
),
|
| 547 |
+
bnb_4bit_compute_dtype=_dtype_from_str(
|
| 548 |
+
model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
|
| 549 |
+
),
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 553 |
+
|
| 554 |
+
try:
|
| 555 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 556 |
+
str(base_dir),
|
| 557 |
+
device_map=device_map,
|
| 558 |
+
trust_remote_code=trust_remote_code,
|
| 559 |
+
low_cpu_mem_usage=True,
|
| 560 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 561 |
+
quantization_config=quant_cfg,
|
| 562 |
+
attn_implementation=attn_impl,
|
| 563 |
+
)
|
| 564 |
+
except Exception as e:
|
| 565 |
+
if attn_impl is not None:
|
| 566 |
+
logger.warning(f"attn_implementation='{attn_impl}' failed: {e}")
|
| 567 |
+
logger.warning("Falling back to default attention implementation.")
|
| 568 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 569 |
+
str(base_dir),
|
| 570 |
+
device_map=device_map,
|
| 571 |
+
trust_remote_code=trust_remote_code,
|
| 572 |
+
low_cpu_mem_usage=True,
|
| 573 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 574 |
+
quantization_config=quant_cfg,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
return model, tokenizer
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 581 |
+
peft_cfg = cfg["peft"]
|
| 582 |
+
model_cfg = cfg["model"]
|
| 583 |
+
tr_cfg = cfg["train"]
|
| 584 |
+
|
| 585 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 586 |
+
return model, None
|
| 587 |
+
|
| 588 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 589 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 590 |
+
|
| 591 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 592 |
+
model.gradient_checkpointing_enable()
|
| 593 |
+
if hasattr(model, "config"):
|
| 594 |
+
model.config.use_cache = False
|
| 595 |
+
|
| 596 |
+
if use_4bit:
|
| 597 |
+
model = prepare_model_for_kbit_training(
|
| 598 |
+
model,
|
| 599 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 603 |
+
if target_modules == "auto":
|
| 604 |
+
target_modules = _infer_target_modules(model)
|
| 605 |
+
|
| 606 |
+
lora_config = LoraConfig(
|
| 607 |
+
r=int(peft_cfg.get("r", 16)),
|
| 608 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 609 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 610 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 611 |
+
task_type="CAUSAL_LM",
|
| 612 |
+
target_modules=target_modules,
|
| 613 |
+
)
|
| 614 |
+
model = get_peft_model(model, lora_config)
|
| 615 |
+
return model, lora_config
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
# --------------------------
|
| 619 |
+
# Merge Logic
|
| 620 |
+
# --------------------------
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def merge_adapter(
|
| 624 |
+
cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
|
| 625 |
+
):
|
| 626 |
+
logger.info(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 627 |
+
|
| 628 |
+
model_cfg = cfg["model"]
|
| 629 |
+
merge_cfg = cfg.get("merge", {})
|
| 630 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 631 |
+
|
| 632 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 633 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 634 |
+
|
| 635 |
+
try:
|
| 636 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 637 |
+
str(base_dir),
|
| 638 |
+
torch_dtype=merged_dtype,
|
| 639 |
+
device_map="cpu",
|
| 640 |
+
low_cpu_mem_usage=True,
|
| 641 |
+
trust_remote_code=trust_remote_code,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 645 |
+
merged = merged.merge_and_unload()
|
| 646 |
+
|
| 647 |
+
# Clean up base model to free memory
|
| 648 |
+
del base
|
| 649 |
+
gc.collect()
|
| 650 |
+
torch.cuda.empty_cache()
|
| 651 |
+
|
| 652 |
+
_ensure_dir(final_dir)
|
| 653 |
+
merged.save_pretrained(
|
| 654 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
# Clean up merged model
|
| 658 |
+
del merged
|
| 659 |
+
gc.collect()
|
| 660 |
+
torch.cuda.empty_cache()
|
| 661 |
+
|
| 662 |
+
tok = AutoTokenizer.from_pretrained(
|
| 663 |
+
str(base_dir), trust_remote_code=trust_remote_code
|
| 664 |
+
)
|
| 665 |
+
if tok.pad_token is None:
|
| 666 |
+
tok.pad_token = tok.eos_token
|
| 667 |
+
tok.save_pretrained(str(final_dir))
|
| 668 |
+
|
| 669 |
+
logger.info("--- Merge complete ---")
|
| 670 |
+
except Exception as e:
|
| 671 |
+
logger.error(f"Merge failed: {e}")
|
| 672 |
+
raise
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
# --------------------------
|
| 676 |
+
# Main
|
| 677 |
+
# --------------------------
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def main():
|
| 681 |
+
ap = argparse.ArgumentParser()
|
| 682 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 683 |
+
ap.add_argument(
|
| 684 |
+
"--merge-only", action="store_true", help="Skip training, just merge adapter"
|
| 685 |
+
)
|
| 686 |
+
args = ap.parse_args()
|
| 687 |
+
|
| 688 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 689 |
+
cfg = yaml.safe_load(f)
|
| 690 |
+
|
| 691 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 692 |
+
_ensure_dir(run_dir / "logs")
|
| 693 |
+
|
| 694 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 695 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 696 |
+
|
| 697 |
+
model_cfg = cfg["model"]
|
| 698 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 699 |
+
repo_path = Path(repo_id)
|
| 700 |
+
|
| 701 |
+
# Local model path -> load directly
|
| 702 |
+
if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
|
| 703 |
+
base_dir = repo_path
|
| 704 |
+
logger.info(f"Using local model at: {base_dir}")
|
| 705 |
+
elif repo_path.exists() and repo_path.is_dir():
|
| 706 |
+
raise ValueError(
|
| 707 |
+
f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}"
|
| 708 |
+
)
|
| 709 |
+
else:
|
| 710 |
+
# HF repo_id -> download
|
| 711 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 712 |
+
if not _looks_like_model_dir(base_dir):
|
| 713 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 714 |
+
snapshot_download(
|
| 715 |
+
repo_id=repo_id,
|
| 716 |
+
revision=model_cfg.get("revision", None),
|
| 717 |
+
local_dir=str(base_dir),
|
| 718 |
+
local_dir_use_symlinks=False,
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 722 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 723 |
+
|
| 724 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 725 |
+
if merge_cfg.get("output_dir"):
|
| 726 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 727 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 728 |
+
else:
|
| 729 |
+
final_dir = run_dir / "final_model"
|
| 730 |
+
|
| 731 |
+
# Merge-only
|
| 732 |
+
if args.merge_only:
|
| 733 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 734 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 735 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 736 |
+
return
|
| 737 |
+
|
| 738 |
+
# Initialize Wandb
|
| 739 |
+
wandb_run = setup_wandb(cfg, run_dir)
|
| 740 |
+
|
| 741 |
+
# Training
|
| 742 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 743 |
+
|
| 744 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 745 |
+
model, _ = apply_peft(cfg, model)
|
| 746 |
+
|
| 747 |
+
# Load reference model for DPO (if using reference model)
|
| 748 |
+
dpo_cfg = cfg.get("dpo", {})
|
| 749 |
+
use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
|
| 750 |
+
reference_free = bool(dpo_cfg.get("reference_free", False))
|
| 751 |
+
|
| 752 |
+
ref_model = None
|
| 753 |
+
if use_reference_model and not reference_free:
|
| 754 |
+
print("Loading reference model (frozen copy)...")
|
| 755 |
+
ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
|
| 756 |
+
ref_model, _ = apply_peft(cfg, ref_model)
|
| 757 |
+
# Freeze reference model
|
| 758 |
+
for param in ref_model.parameters():
|
| 759 |
+
param.requires_grad = False
|
| 760 |
+
ref_model.eval()
|
| 761 |
+
print("Reference model loaded and frozen")
|
| 762 |
+
|
| 763 |
+
train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)
|
| 764 |
+
|
| 765 |
+
tr_cfg = cfg["train"]
|
| 766 |
+
|
| 767 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 768 |
+
use_fp16 = dtype == torch.float16
|
| 769 |
+
use_bf16 = dtype == torch.bfloat16
|
| 770 |
+
|
| 771 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 772 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 773 |
+
|
| 774 |
+
# Dynamic evaluation strategy parameter handling
|
| 775 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 776 |
+
eval_key = (
|
| 777 |
+
"eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
# Setup reporting based on wandb availability
|
| 781 |
+
report_to = []
|
| 782 |
+
if wandb_run is not None:
|
| 783 |
+
report_to.append("wandb")
|
| 784 |
+
|
| 785 |
+
# Validate and adjust training parameters
|
| 786 |
+
max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
|
| 787 |
+
if max_grad_norm <= 0:
|
| 788 |
+
logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
|
| 789 |
+
max_grad_norm = 1.0
|
| 790 |
+
|
| 791 |
+
ta_kwargs = dict(
|
| 792 |
+
output_dir=str(ckpt_dir),
|
| 793 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 794 |
+
num_train_epochs=num_train_epochs,
|
| 795 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 796 |
+
per_device_eval_batch_size=int(
|
| 797 |
+
tr_cfg.get(
|
| 798 |
+
"per_device_eval_batch_size",
|
| 799 |
+
tr_cfg.get("per_device_train_batch_size", 1),
|
| 800 |
+
)
|
| 801 |
+
),
|
| 802 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 803 |
+
learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
|
| 804 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 805 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 806 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 807 |
+
optim=str(
|
| 808 |
+
tr_cfg.get(
|
| 809 |
+
"optim",
|
| 810 |
+
(
|
| 811 |
+
"paged_adamw_8bit"
|
| 812 |
+
if bool(model_cfg.get("use_4bit", False))
|
| 813 |
+
else "adamw_torch"
|
| 814 |
+
),
|
| 815 |
+
)
|
| 816 |
+
),
|
| 817 |
+
max_grad_norm=max_grad_norm,
|
| 818 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 819 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 820 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 821 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 822 |
+
eval_steps=int(tr_cfg.get("eval_steps", 50)),
|
| 823 |
+
load_best_model_at_end=(
|
| 824 |
+
bool(tr_cfg.get("load_best_model_at_end", True))
|
| 825 |
+
if eval_ds is not None
|
| 826 |
+
else False
|
| 827 |
+
),
|
| 828 |
+
metric_for_best_model="eval_loss",
|
| 829 |
+
greater_is_better=False,
|
| 830 |
+
fp16=use_fp16,
|
| 831 |
+
bf16=use_bf16,
|
| 832 |
+
report_to=report_to,
|
| 833 |
+
remove_unused_columns=False,
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
# Set the correct argument name for this transformers version
|
| 837 |
+
ta_kwargs[eval_key] = str(
|
| 838 |
+
tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 842 |
+
|
| 843 |
+
# Setup callbacks
|
| 844 |
+
callbacks = [JsonlLoggerCallback(run_dir)]
|
| 845 |
+
|
| 846 |
+
# Add early stopping callback if enabled
|
| 847 |
+
early_stopping_cfg = tr_cfg.get("early_stopping", {})
|
| 848 |
+
if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
|
| 849 |
+
early_stopping_callback = EarlyStoppingCallback(
|
| 850 |
+
early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
|
| 851 |
+
early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
|
| 852 |
+
)
|
| 853 |
+
callbacks.append(early_stopping_callback)
|
| 854 |
+
print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
|
| 855 |
+
f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
|
| 856 |
+
|
| 857 |
+
# DPO-specific parameters
|
| 858 |
+
beta = float(dpo_cfg.get("beta", 0.1))
|
| 859 |
+
label_smoothing = float(dpo_cfg.get("label_smoothing", 0.0))
|
| 860 |
+
loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
|
| 861 |
+
max_length = int(cfg["data"].get("max_length", 2048))
|
| 862 |
+
max_prompt_length = int(cfg["data"].get("max_prompt_length", max_length // 2))
|
| 863 |
+
|
| 864 |
+
logger.info(f"DPO Training with beta={beta}, loss_type={loss_type}")
|
| 865 |
+
|
| 866 |
+
# Get evaluation strategy from config
|
| 867 |
+
eval_strategy_val = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
|
| 868 |
+
|
| 869 |
+
# Create DPOConfig with all training and DPO-specific parameters
|
| 870 |
+
dpo_config = DPOConfig(
|
| 871 |
+
output_dir=str(run_dir),
|
| 872 |
+
num_train_epochs=int(tr_cfg.get("num_train_epochs", 3)),
|
| 873 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
|
| 874 |
+
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
|
| 875 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
|
| 876 |
+
learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
|
| 877 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
|
| 878 |
+
adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
|
| 879 |
+
adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
|
| 880 |
+
adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
|
| 881 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 882 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
|
| 883 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 884 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 885 |
+
save_steps=int(tr_cfg.get("save_steps", 100)),
|
| 886 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 887 |
+
eval_steps=int(tr_cfg.get("eval_steps", 100)) if eval_ds is not None else None,
|
| 888 |
+
eval_strategy=eval_strategy_val,
|
| 889 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 890 |
+
load_best_model_at_end=(
|
| 891 |
+
bool(tr_cfg.get("load_best_model_at_end", False))
|
| 892 |
+
if eval_ds is not None
|
| 893 |
+
else False
|
| 894 |
+
),
|
| 895 |
+
metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
|
| 896 |
+
greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
|
| 897 |
+
fp16=use_fp16,
|
| 898 |
+
bf16=use_bf16,
|
| 899 |
+
report_to=report_to,
|
| 900 |
+
remove_unused_columns=False,
|
| 901 |
+
# DPO-specific parameters
|
| 902 |
+
beta=beta,
|
| 903 |
+
label_smoothing=label_smoothing,
|
| 904 |
+
loss_type=loss_type,
|
| 905 |
+
max_length=max_length,
|
| 906 |
+
max_prompt_length=max_prompt_length,
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
trainer = DPOTrainer(
|
| 910 |
+
model=model,
|
| 911 |
+
ref_model=ref_model,
|
| 912 |
+
args=dpo_config,
|
| 913 |
+
train_dataset=train_ds,
|
| 914 |
+
eval_dataset=eval_ds,
|
| 915 |
+
processing_class=tokenizer,
|
| 916 |
+
callbacks=callbacks,
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
# Resume
|
| 920 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 921 |
+
if resume_from == "auto":
|
| 922 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 923 |
+
resume_from = last if last else None
|
| 924 |
+
if resume_from:
|
| 925 |
+
logger.info(f"Resuming from {resume_from}")
|
| 926 |
+
|
| 927 |
+
logger.info("Starting DPO training...")
|
| 928 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 929 |
+
|
| 930 |
+
trainer.save_model(str(best_adapter_dir))
|
| 931 |
+
logger.info(f"Saved best adapter -> {best_adapter_dir}")
|
| 932 |
+
|
| 933 |
+
if eval_ds is not None:
|
| 934 |
+
metrics = trainer.evaluate()
|
| 935 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 936 |
+
json.dump(metrics, f, indent=2)
|
| 937 |
+
print(f"Final metrics: {metrics}")
|
| 938 |
+
|
| 939 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 940 |
+
del trainer, model
|
| 941 |
+
if ref_model is not None:
|
| 942 |
+
del ref_model
|
| 943 |
+
torch.cuda.empty_cache()
|
| 944 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 945 |
+
else:
|
| 946 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 947 |
+
|
| 948 |
+
# Finish Wandb run
|
| 949 |
+
finish_wandb()
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
if __name__ == "__main__":
|
| 953 |
+
main()
|
DPO-14b/run_dpo.py.backup
ADDED
|
@@ -0,0 +1,923 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import inspect
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import yaml
|
| 11 |
+
from datasets import load_dataset, DatasetDict
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
from transformers import (
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
BitsAndBytesConfig,
|
| 17 |
+
TrainingArguments,
|
| 18 |
+
TrainerCallback,
|
| 19 |
+
EarlyStoppingCallback,
|
| 20 |
+
set_seed,
|
| 21 |
+
)
|
| 22 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 23 |
+
from peft import (
|
| 24 |
+
LoraConfig,
|
| 25 |
+
get_peft_model,
|
| 26 |
+
prepare_model_for_kbit_training,
|
| 27 |
+
PeftModel,
|
| 28 |
+
)
|
| 29 |
+
from trl import DPOTrainer, DPOConfig
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
import wandb
|
| 33 |
+
WANDB_AVAILABLE = True
|
| 34 |
+
except ImportError:
|
| 35 |
+
WANDB_AVAILABLE = False
|
| 36 |
+
wandb = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --------------------------
|
| 40 |
+
# Helpers
|
| 41 |
+
# --------------------------
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _dtype_from_str(s: str) -> torch.dtype:
|
| 45 |
+
s = (s or "").lower()
|
| 46 |
+
if s in ("float16", "fp16"):
|
| 47 |
+
return torch.float16
|
| 48 |
+
if s in ("bfloat16", "bf16"):
|
| 49 |
+
return torch.bfloat16
|
| 50 |
+
if s in ("float32", "fp32"):
|
| 51 |
+
return torch.float32
|
| 52 |
+
raise ValueError(f"Unknown torch_dtype: {s}")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _now_iso() -> str:
|
| 56 |
+
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _safe_exp(x: float) -> float:
|
| 60 |
+
x = min(float(x), 50.0)
|
| 61 |
+
return float(math.exp(x))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _ensure_dir(p: Path) -> Path:
|
| 65 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
return p
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _looks_like_model_dir(p: Path) -> bool:
|
| 70 |
+
if not p.exists() or not p.is_dir():
|
| 71 |
+
return False
|
| 72 |
+
if (p / "config.json").exists():
|
| 73 |
+
return True
|
| 74 |
+
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
|
| 75 |
+
return True
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _infer_target_modules(model) -> List[str]:
|
| 80 |
+
names = set()
|
| 81 |
+
for n, _ in model.named_modules():
|
| 82 |
+
names.add(n.split(".")[-1])
|
| 83 |
+
|
| 84 |
+
for group in [
|
| 85 |
+
["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 86 |
+
["Wqkv", "out_proj"],
|
| 87 |
+
["query_key_value", "dense"],
|
| 88 |
+
["c_attn", "c_proj"],
|
| 89 |
+
]:
|
| 90 |
+
if all(x in names for x in group):
|
| 91 |
+
return group
|
| 92 |
+
|
| 93 |
+
fallback = [
|
| 94 |
+
x
|
| 95 |
+
for x in [
|
| 96 |
+
"q_proj",
|
| 97 |
+
"k_proj",
|
| 98 |
+
"v_proj",
|
| 99 |
+
"o_proj",
|
| 100 |
+
"c_attn",
|
| 101 |
+
"c_proj",
|
| 102 |
+
"out_proj",
|
| 103 |
+
"dense",
|
| 104 |
+
]
|
| 105 |
+
if x in names
|
| 106 |
+
]
|
| 107 |
+
if fallback:
|
| 108 |
+
return fallback
|
| 109 |
+
|
| 110 |
+
raise ValueError(
|
| 111 |
+
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 116 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# --------------------------
|
| 120 |
+
# Wandb Integration
|
| 121 |
+
# --------------------------
|
| 122 |
+
|
| 123 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 124 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 125 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 126 |
+
|
| 127 |
+
if not wandb_cfg.get("enabled", False):
|
| 128 |
+
print("Wandb logging disabled")
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
if not WANDB_AVAILABLE:
|
| 132 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
project = wandb_cfg.get("project", "dpo-training")
|
| 136 |
+
entity = wandb_cfg.get("entity", None)
|
| 137 |
+
name = wandb_cfg.get("name", None)
|
| 138 |
+
tags = wandb_cfg.get("tags", [])
|
| 139 |
+
notes = wandb_cfg.get("notes", None)
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
wandb.init(
|
| 143 |
+
project=project,
|
| 144 |
+
entity=entity,
|
| 145 |
+
name=name,
|
| 146 |
+
tags=tags,
|
| 147 |
+
notes=notes,
|
| 148 |
+
dir=str(run_dir),
|
| 149 |
+
config={
|
| 150 |
+
"model": cfg.get("model", {}),
|
| 151 |
+
"data": cfg.get("data", {}),
|
| 152 |
+
"peft": cfg.get("peft", {}),
|
| 153 |
+
"dpo": cfg.get("dpo", {}),
|
| 154 |
+
"train": cfg.get("train", {}),
|
| 155 |
+
"run_dir": str(run_dir),
|
| 156 |
+
}
|
| 157 |
+
)
|
| 158 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 159 |
+
return wandb
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def finish_wandb():
|
| 166 |
+
"""Finish Wandb run if active."""
|
| 167 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 168 |
+
wandb.finish()
|
| 169 |
+
print("Wandb run finished")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# --------------------------
|
| 173 |
+
# JSONL Logger Callback
|
| 174 |
+
# --------------------------
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 178 |
+
def __init__(self, run_dir: Path):
|
| 179 |
+
self.run_dir = run_dir
|
| 180 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 181 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 182 |
+
self.start_time = None
|
| 183 |
+
|
| 184 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 185 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 186 |
+
return None
|
| 187 |
+
elapsed = time.time() - self.start_time
|
| 188 |
+
sec_per_step = elapsed / global_step
|
| 189 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 190 |
+
h = int(remaining // 3600)
|
| 191 |
+
m = int((remaining % 3600) // 60)
|
| 192 |
+
s = int(remaining % 60)
|
| 193 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 194 |
+
|
| 195 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 196 |
+
self.start_time = time.time()
|
| 197 |
+
|
| 198 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 199 |
+
if not logs:
|
| 200 |
+
return
|
| 201 |
+
|
| 202 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 203 |
+
progress_pct = (
|
| 204 |
+
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 205 |
+
)
|
| 206 |
+
epoch_pct = None
|
| 207 |
+
if (
|
| 208 |
+
state.epoch is not None
|
| 209 |
+
and args.num_train_epochs
|
| 210 |
+
and args.num_train_epochs > 0
|
| 211 |
+
):
|
| 212 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 213 |
+
|
| 214 |
+
payload = {
|
| 215 |
+
"ts": _now_iso(),
|
| 216 |
+
"event": "train_log",
|
| 217 |
+
"step": int(state.global_step),
|
| 218 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 219 |
+
"progress_pct": (
|
| 220 |
+
round(progress_pct, 2) if progress_pct is not None else None
|
| 221 |
+
),
|
| 222 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 223 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 224 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 225 |
+
**logs,
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 229 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 230 |
+
|
| 231 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 232 |
+
if not metrics:
|
| 233 |
+
return
|
| 234 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 235 |
+
|
| 236 |
+
payload = {
|
| 237 |
+
"ts": _now_iso(),
|
| 238 |
+
"event": "eval",
|
| 239 |
+
"step": int(state.global_step),
|
| 240 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 241 |
+
**metrics,
|
| 242 |
+
}
|
| 243 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 244 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# --------------------------
|
| 248 |
+
# Custom Exceptions
|
| 249 |
+
# --------------------------
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class DataFormattingError(Exception):
|
| 253 |
+
"""Exception raised for errors in data formatting."""
|
| 254 |
+
pass
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class DataValidationError(Exception):
|
| 258 |
+
"""Exception raised for errors in data validation."""
|
| 259 |
+
pass
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# --------------------------
|
| 263 |
+
# Data Pipeline (DPO Format)
|
| 264 |
+
# --------------------------
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def format_dpo_example(
|
| 268 |
+
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
|
| 269 |
+
) -> Dict[str, Any]:
|
| 270 |
+
"""
|
| 271 |
+
Format DPO data which requires prompt, chosen, and rejected completions.
|
| 272 |
+
Returns formatted prompt, chosen, and rejected texts.
|
| 273 |
+
Raises DataFormattingError if formatting fails.
|
| 274 |
+
"""
|
| 275 |
+
data_cfg = cfg["data"]
|
| 276 |
+
format_type = data_cfg.get("format_type", "chatml")
|
| 277 |
+
|
| 278 |
+
# Get field names from config
|
| 279 |
+
prompt_field = data_cfg.get("prompt_field", "prompt")
|
| 280 |
+
chosen_field = data_cfg.get("chosen_field", "chosen")
|
| 281 |
+
rejected_field = data_cfg.get("rejected_field", "rejected")
|
| 282 |
+
|
| 283 |
+
# Extract text from example
|
| 284 |
+
prompt = example.get(prompt_field, "")
|
| 285 |
+
chosen = example.get(chosen_field, "")
|
| 286 |
+
rejected = example.get(rejected_field, "")
|
| 287 |
+
|
| 288 |
+
# Validate required fields
|
| 289 |
+
if not prompt:
|
| 290 |
+
raise DataFormattingError(f"Empty prompt field: {prompt_field}")
|
| 291 |
+
if not chosen:
|
| 292 |
+
raise DataFormattingError(f"Empty chosen field: {chosen_field}")
|
| 293 |
+
if not rejected:
|
| 294 |
+
raise DataFormattingError(f"Empty rejected field: {rejected_field}")
|
| 295 |
+
|
| 296 |
+
if format_type == "chatml":
|
| 297 |
+
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
|
| 298 |
+
|
| 299 |
+
# Format prompt with system message
|
| 300 |
+
messages = []
|
| 301 |
+
if system_prompt:
|
| 302 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 303 |
+
messages.append({"role": "user", "content": prompt})
|
| 304 |
+
|
| 305 |
+
# Apply chat template for prompt only (without assistant response)
|
| 306 |
+
formatted_prompt = tokenizer.apply_chat_template(
|
| 307 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Chosen and rejected are just the completions (will be added by DPOTrainer)
|
| 311 |
+
formatted_chosen = chosen
|
| 312 |
+
formatted_rejected = rejected
|
| 313 |
+
|
| 314 |
+
# Add EOS token to completions
|
| 315 |
+
if tokenizer.eos_token:
|
| 316 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 317 |
+
formatted_chosen += tokenizer.eos_token
|
| 318 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 319 |
+
formatted_rejected += tokenizer.eos_token
|
| 320 |
+
|
| 321 |
+
elif format_type == "alpaca":
|
| 322 |
+
# Alpaca format
|
| 323 |
+
prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
|
| 324 |
+
formatted_prompt = prefix
|
| 325 |
+
formatted_chosen = chosen
|
| 326 |
+
formatted_rejected = rejected
|
| 327 |
+
|
| 328 |
+
if tokenizer.eos_token:
|
| 329 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 330 |
+
formatted_chosen += tokenizer.eos_token
|
| 331 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 332 |
+
formatted_rejected += tokenizer.eos_token
|
| 333 |
+
|
| 334 |
+
elif format_type == "custom":
|
| 335 |
+
# Custom template
|
| 336 |
+
template = data_cfg.get("custom_template", "{prompt}")
|
| 337 |
+
formatted_prompt = template.format(prompt=prompt)
|
| 338 |
+
formatted_chosen = chosen
|
| 339 |
+
formatted_rejected = rejected
|
| 340 |
+
|
| 341 |
+
if tokenizer.eos_token:
|
| 342 |
+
if not formatted_chosen.endswith(tokenizer.eos_token):
|
| 343 |
+
formatted_chosen += tokenizer.eos_token
|
| 344 |
+
if not formatted_rejected.endswith(tokenizer.eos_token):
|
| 345 |
+
formatted_rejected += tokenizer.eos_token
|
| 346 |
+
else:
|
| 347 |
+
raise ValueError(f"Unsupported format_type: {format_type}")
|
| 348 |
+
|
| 349 |
+
return {
|
| 350 |
+
"prompt": formatted_prompt,
|
| 351 |
+
"chosen": formatted_chosen,
|
| 352 |
+
"rejected": formatted_rejected,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def validate_dpo_data(dataset, stage: str = "train") -> None:
|
| 357 |
+
"""
|
| 358 |
+
Validate DPO dataset has all required fields and proper structure.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
dataset: Dataset to validate
|
| 362 |
+
stage: Training stage ("train" or "eval")
|
| 363 |
+
|
| 364 |
+
Raises:
|
| 365 |
+
DataValidationError if validation fails
|
| 366 |
+
"""
|
| 367 |
+
required_fields = ["prompt", "chosen", "rejected"]
|
| 368 |
+
|
| 369 |
+
# Check required fields exist
|
| 370 |
+
for field in required_fields:
|
| 371 |
+
if field not in dataset.column_names:
|
| 372 |
+
raise DataValidationError(
|
| 373 |
+
f"{stage} dataset missing required field: {field}. "
|
| 374 |
+
f"Available fields: {dataset.column_names}"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# Sample validation - check first example
|
| 378 |
+
if len(dataset) > 0:
|
| 379 |
+
sample = dataset[0]
|
| 380 |
+
for field in required_fields:
|
| 381 |
+
if not sample[field] or len(sample[field].strip()) == 0:
|
| 382 |
+
logger.warning(f"{stage} dataset has empty {field} in first example")
|
| 383 |
+
|
| 384 |
+
logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
|
| 388 |
+
"""
|
| 389 |
+
Build datasets for DPO training.
|
| 390 |
+
Expected JSONL format: {"prompt": "...", "chosen": "...", "rejected": "..."}
|
| 391 |
+
Or with custom field names specified in config.
|
| 392 |
+
"""
|
| 393 |
+
data_cfg = cfg["data"]
|
| 394 |
+
train_path = data_cfg["train_jsonl"]
|
| 395 |
+
eval_path = data_cfg.get("eval_jsonl", None)
|
| 396 |
+
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
|
| 397 |
+
shuffle = bool(data_cfg.get("shuffle", True))
|
| 398 |
+
num_proc = int(data_cfg.get("num_proc", 4))
|
| 399 |
+
|
| 400 |
+
# Ensure tokenizer has pad token
|
| 401 |
+
if tokenizer.pad_token is None:
|
| 402 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 403 |
+
|
| 404 |
+
# Load datasets
|
| 405 |
+
ds = load_dataset("json", data_files={"train": train_path})
|
| 406 |
+
|
| 407 |
+
if eval_path:
|
| 408 |
+
ds_eval = load_dataset("json", data_files={"eval": eval_path})
|
| 409 |
+
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
|
| 410 |
+
else:
|
| 411 |
+
if 0.0 < split_ratio < 1.0:
|
| 412 |
+
split = ds["train"].train_test_split(
|
| 413 |
+
test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
|
| 414 |
+
)
|
| 415 |
+
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
|
| 416 |
+
else:
|
| 417 |
+
dsd = DatasetDict({"train": ds["train"], "eval": None})
|
| 418 |
+
|
| 419 |
+
# Format DPO examples with error handling
|
| 420 |
+
def format_fn(examples):
|
| 421 |
+
prompts = []
|
| 422 |
+
chosen_list = []
|
| 423 |
+
rejected_list = []
|
| 424 |
+
errors = 0
|
| 425 |
+
|
| 426 |
+
for i in range(len(examples[list(examples.keys())[0]])):
|
| 427 |
+
example = {k: examples[k][i] for k in examples.keys()}
|
| 428 |
+
try:
|
| 429 |
+
formatted = format_dpo_example(example, cfg, tokenizer)
|
| 430 |
+
prompts.append(formatted["prompt"])
|
| 431 |
+
chosen_list.append(formatted["chosen"])
|
| 432 |
+
rejected_list.append(formatted["rejected"])
|
| 433 |
+
except (DataFormattingError, Exception) as e:
|
| 434 |
+
errors += 1
|
| 435 |
+
if errors <= 5: # Log first 5 errors
|
| 436 |
+
logger.warning(f"Failed to format example {i}: {e}")
|
| 437 |
+
# Add empty placeholder to maintain batch structure
|
| 438 |
+
prompts.append("")
|
| 439 |
+
chosen_list.append("")
|
| 440 |
+
rejected_list.append("")
|
| 441 |
+
|
| 442 |
+
if errors > 0:
|
| 443 |
+
logger.warning(f"Total formatting errors in batch: {errors}")
|
| 444 |
+
|
| 445 |
+
return {
|
| 446 |
+
"prompt": prompts,
|
| 447 |
+
"chosen": chosen_list,
|
| 448 |
+
"rejected": rejected_list,
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
logger.info("Formatting train DPO data...")
|
| 452 |
+
formatted_train = dsd["train"].map(
|
| 453 |
+
format_fn,
|
| 454 |
+
batched=True,
|
| 455 |
+
num_proc=num_proc,
|
| 456 |
+
remove_columns=dsd["train"].column_names,
|
| 457 |
+
desc="Formatting train DPO data",
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
# Filter out failed examples (empty prompts)
|
| 461 |
+
formatted_train = formatted_train.filter(lambda x: len(x["prompt"]) > 0)
|
| 462 |
+
logger.info(f"Train dataset after filtering: {len(formatted_train)} examples")
|
| 463 |
+
|
| 464 |
+
# Validate formatted data
|
| 465 |
+
validate_dpo_data(formatted_train, "train")
|
| 466 |
+
|
| 467 |
+
formatted_eval = None
|
| 468 |
+
if dsd["eval"] is not None:
|
| 469 |
+
logger.info("Formatting eval DPO data...")
|
| 470 |
+
formatted_eval = dsd["eval"].map(
|
| 471 |
+
format_fn,
|
| 472 |
+
batched=True,
|
| 473 |
+
num_proc=num_proc,
|
| 474 |
+
remove_columns=dsd["eval"].column_names,
|
| 475 |
+
desc="Formatting eval DPO data",
|
| 476 |
+
)
|
| 477 |
+
formatted_eval = formatted_eval.filter(lambda x: len(x["prompt"]) > 0)
|
| 478 |
+
logger.info(f"Eval dataset after filtering: {len(formatted_eval)} examples")
|
| 479 |
+
validate_dpo_data(formatted_eval, "eval")
|
| 480 |
+
|
| 481 |
+
if shuffle:
|
| 482 |
+
formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
|
| 483 |
+
|
| 484 |
+
return formatted_train, formatted_eval
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# --------------------------
|
| 488 |
+
# Model Loading + PEFT
|
| 489 |
+
# --------------------------
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
|
| 493 |
+
model_cfg = cfg["model"]
|
| 494 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 495 |
+
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
|
| 496 |
+
device_map = model_cfg.get("device_map", "auto")
|
| 497 |
+
|
| 498 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 499 |
+
str(base_dir),
|
| 500 |
+
use_fast=use_fast,
|
| 501 |
+
trust_remote_code=trust_remote_code,
|
| 502 |
+
)
|
| 503 |
+
if tokenizer.pad_token is None:
|
| 504 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 505 |
+
|
| 506 |
+
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 507 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 508 |
+
|
| 509 |
+
quant_cfg = None
|
| 510 |
+
if use_4bit:
|
| 511 |
+
quant_cfg = BitsAndBytesConfig(
|
| 512 |
+
load_in_4bit=True,
|
| 513 |
+
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
|
| 514 |
+
bnb_4bit_use_double_quant=bool(
|
| 515 |
+
model_cfg.get("bnb_4bit_use_double_quant", True)
|
| 516 |
+
),
|
| 517 |
+
bnb_4bit_compute_dtype=_dtype_from_str(
|
| 518 |
+
model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
|
| 519 |
+
),
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
attn_impl = _choose_attn_impl(cfg)
|
| 523 |
+
|
| 524 |
+
try:
|
| 525 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 526 |
+
str(base_dir),
|
| 527 |
+
device_map=device_map,
|
| 528 |
+
trust_remote_code=trust_remote_code,
|
| 529 |
+
low_cpu_mem_usage=True,
|
| 530 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 531 |
+
quantization_config=quant_cfg,
|
| 532 |
+
attn_implementation=attn_impl,
|
| 533 |
+
)
|
| 534 |
+
except Exception as e:
|
| 535 |
+
if attn_impl is not None:
|
| 536 |
+
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
|
| 537 |
+
print("[warn] Falling back to default attention implementation.")
|
| 538 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 539 |
+
str(base_dir),
|
| 540 |
+
device_map=device_map,
|
| 541 |
+
trust_remote_code=trust_remote_code,
|
| 542 |
+
low_cpu_mem_usage=True,
|
| 543 |
+
torch_dtype=(torch_dtype if not use_4bit else None),
|
| 544 |
+
quantization_config=quant_cfg,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
return model, tokenizer
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def apply_peft(cfg: Dict[str, Any], model):
|
| 551 |
+
peft_cfg = cfg["peft"]
|
| 552 |
+
model_cfg = cfg["model"]
|
| 553 |
+
tr_cfg = cfg["train"]
|
| 554 |
+
|
| 555 |
+
if not bool(peft_cfg.get("enabled", True)):
|
| 556 |
+
return model, None
|
| 557 |
+
|
| 558 |
+
use_4bit = bool(model_cfg.get("use_4bit", False))
|
| 559 |
+
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
|
| 560 |
+
|
| 561 |
+
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
|
| 562 |
+
model.gradient_checkpointing_enable()
|
| 563 |
+
if hasattr(model, "config"):
|
| 564 |
+
model.config.use_cache = False
|
| 565 |
+
|
| 566 |
+
if use_4bit:
|
| 567 |
+
model = prepare_model_for_kbit_training(
|
| 568 |
+
model,
|
| 569 |
+
use_gradient_checkpointing=gradient_checkpointing,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
target_modules = peft_cfg.get("target_modules", "auto")
|
| 573 |
+
if target_modules == "auto":
|
| 574 |
+
target_modules = _infer_target_modules(model)
|
| 575 |
+
|
| 576 |
+
lora_config = LoraConfig(
|
| 577 |
+
r=int(peft_cfg.get("r", 16)),
|
| 578 |
+
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
|
| 579 |
+
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
|
| 580 |
+
bias=str(peft_cfg.get("bias", "none")),
|
| 581 |
+
task_type="CAUSAL_LM",
|
| 582 |
+
target_modules=target_modules,
|
| 583 |
+
)
|
| 584 |
+
model = get_peft_model(model, lora_config)
|
| 585 |
+
return model, lora_config
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
# --------------------------
|
| 589 |
+
# Merge Logic
|
| 590 |
+
# --------------------------
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def merge_adapter(
|
| 594 |
+
cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
|
| 595 |
+
):
|
| 596 |
+
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
|
| 597 |
+
|
| 598 |
+
model_cfg = cfg["model"]
|
| 599 |
+
merge_cfg = cfg.get("merge", {})
|
| 600 |
+
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
|
| 601 |
+
|
| 602 |
+
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
|
| 603 |
+
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
|
| 604 |
+
|
| 605 |
+
try:
|
| 606 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 607 |
+
str(base_dir),
|
| 608 |
+
torch_dtype=merged_dtype,
|
| 609 |
+
device_map="cpu",
|
| 610 |
+
low_cpu_mem_usage=True,
|
| 611 |
+
trust_remote_code=trust_remote_code,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
merged = PeftModel.from_pretrained(base, str(adapter_dir))
|
| 615 |
+
merged = merged.merge_and_unload()
|
| 616 |
+
|
| 617 |
+
# Clean up base model to free memory
|
| 618 |
+
del base
|
| 619 |
+
gc.collect()
|
| 620 |
+
torch.cuda.empty_cache()
|
| 621 |
+
|
| 622 |
+
_ensure_dir(final_dir)
|
| 623 |
+
merged.save_pretrained(
|
| 624 |
+
str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
# Clean up merged model
|
| 628 |
+
del merged
|
| 629 |
+
gc.collect()
|
| 630 |
+
torch.cuda.empty_cache()
|
| 631 |
+
|
| 632 |
+
tok = AutoTokenizer.from_pretrained(
|
| 633 |
+
str(base_dir), trust_remote_code=trust_remote_code
|
| 634 |
+
)
|
| 635 |
+
if tok.pad_token is None:
|
| 636 |
+
tok.pad_token = tok.eos_token
|
| 637 |
+
tok.save_pretrained(str(final_dir))
|
| 638 |
+
|
| 639 |
+
print("--- Merge complete ---")
|
| 640 |
+
except Exception as e:
|
| 641 |
+
logger.error(f"Merge failed: {e}")
|
| 642 |
+
raise
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
# --------------------------
|
| 646 |
+
# Main
|
| 647 |
+
# --------------------------
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def main():
|
| 651 |
+
ap = argparse.ArgumentParser()
|
| 652 |
+
ap.add_argument("--config", required=True, help="Path to YAML config")
|
| 653 |
+
ap.add_argument(
|
| 654 |
+
"--merge-only", action="store_true", help="Skip training, just merge adapter"
|
| 655 |
+
)
|
| 656 |
+
args = ap.parse_args()
|
| 657 |
+
|
| 658 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 659 |
+
cfg = yaml.safe_load(f)
|
| 660 |
+
|
| 661 |
+
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
|
| 662 |
+
_ensure_dir(run_dir / "logs")
|
| 663 |
+
|
| 664 |
+
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
|
| 665 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 666 |
+
|
| 667 |
+
model_cfg = cfg["model"]
|
| 668 |
+
repo_id = str(model_cfg["repo_id"]).strip()
|
| 669 |
+
repo_path = Path(repo_id)
|
| 670 |
+
|
| 671 |
+
# Local model path -> load directly
|
| 672 |
+
if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
|
| 673 |
+
base_dir = repo_path
|
| 674 |
+
print(f"Using local model at: {base_dir}")
|
| 675 |
+
elif repo_path.exists() and repo_path.is_dir():
|
| 676 |
+
raise ValueError(
|
| 677 |
+
f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}"
|
| 678 |
+
)
|
| 679 |
+
else:
|
| 680 |
+
# HF repo_id -> download
|
| 681 |
+
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
|
| 682 |
+
if not _looks_like_model_dir(base_dir):
|
| 683 |
+
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
|
| 684 |
+
snapshot_download(
|
| 685 |
+
repo_id=repo_id,
|
| 686 |
+
revision=model_cfg.get("revision", None),
|
| 687 |
+
local_dir=str(base_dir),
|
| 688 |
+
local_dir_use_symlinks=False,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
|
| 692 |
+
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
|
| 693 |
+
|
| 694 |
+
merge_cfg = cfg.get("merge", {}) or {}
|
| 695 |
+
if merge_cfg.get("output_dir"):
|
| 696 |
+
od = Path(str(merge_cfg["output_dir"]))
|
| 697 |
+
final_dir = od if od.is_absolute() else (run_dir / od)
|
| 698 |
+
else:
|
| 699 |
+
final_dir = run_dir / "final_model"
|
| 700 |
+
|
| 701 |
+
# Merge-only
|
| 702 |
+
if args.merge_only:
|
| 703 |
+
if not _looks_like_model_dir(best_adapter_dir):
|
| 704 |
+
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
|
| 705 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 706 |
+
return
|
| 707 |
+
|
| 708 |
+
# Initialize Wandb
|
| 709 |
+
wandb_run = setup_wandb(cfg, run_dir)
|
| 710 |
+
|
| 711 |
+
# Training
|
| 712 |
+
set_seed(int(cfg["run"].get("seed", 42)))
|
| 713 |
+
|
| 714 |
+
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
|
| 715 |
+
model, _ = apply_peft(cfg, model)
|
| 716 |
+
|
| 717 |
+
# Load reference model for DPO (if using reference model)
|
| 718 |
+
dpo_cfg = cfg.get("dpo", {})
|
| 719 |
+
use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
|
| 720 |
+
reference_free = bool(dpo_cfg.get("reference_free", False))
|
| 721 |
+
|
| 722 |
+
ref_model = None
|
| 723 |
+
if use_reference_model and not reference_free:
|
| 724 |
+
print("Loading reference model (frozen copy)...")
|
| 725 |
+
ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
|
| 726 |
+
ref_model, _ = apply_peft(cfg, ref_model)
|
| 727 |
+
# Freeze reference model
|
| 728 |
+
for param in ref_model.parameters():
|
| 729 |
+
param.requires_grad = False
|
| 730 |
+
ref_model.eval()
|
| 731 |
+
print("Reference model loaded and frozen")
|
| 732 |
+
|
| 733 |
+
train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)
|
| 734 |
+
|
| 735 |
+
tr_cfg = cfg["train"]
|
| 736 |
+
|
| 737 |
+
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
|
| 738 |
+
use_fp16 = dtype == torch.float16
|
| 739 |
+
use_bf16 = dtype == torch.bfloat16
|
| 740 |
+
|
| 741 |
+
max_steps = int(tr_cfg.get("max_steps", 0))
|
| 742 |
+
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
|
| 743 |
+
|
| 744 |
+
# Dynamic evaluation strategy parameter handling
|
| 745 |
+
ta_params = inspect.signature(TrainingArguments.__init__).parameters
|
| 746 |
+
eval_key = (
|
| 747 |
+
"eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# Setup reporting based on wandb availability
|
| 751 |
+
report_to = []
|
| 752 |
+
if wandb_run is not None:
|
| 753 |
+
report_to.append("wandb")
|
| 754 |
+
|
| 755 |
+
# Validate and adjust training parameters
|
| 756 |
+
max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
|
| 757 |
+
if max_grad_norm <= 0:
|
| 758 |
+
logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
|
| 759 |
+
max_grad_norm = 1.0
|
| 760 |
+
|
| 761 |
+
ta_kwargs = dict(
|
| 762 |
+
output_dir=str(ckpt_dir),
|
| 763 |
+
max_steps=max_steps if max_steps > 0 else -1,
|
| 764 |
+
num_train_epochs=num_train_epochs,
|
| 765 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
|
| 766 |
+
per_device_eval_batch_size=int(
|
| 767 |
+
tr_cfg.get(
|
| 768 |
+
"per_device_eval_batch_size",
|
| 769 |
+
tr_cfg.get("per_device_train_batch_size", 1),
|
| 770 |
+
)
|
| 771 |
+
),
|
| 772 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
|
| 773 |
+
learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
|
| 774 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
|
| 775 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 776 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
|
| 777 |
+
optim=str(
|
| 778 |
+
tr_cfg.get(
|
| 779 |
+
"optim",
|
| 780 |
+
(
|
| 781 |
+
"paged_adamw_8bit"
|
| 782 |
+
if bool(model_cfg.get("use_4bit", False))
|
| 783 |
+
else "adamw_torch"
|
| 784 |
+
),
|
| 785 |
+
)
|
| 786 |
+
),
|
| 787 |
+
max_grad_norm=max_grad_norm,
|
| 788 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 789 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 790 |
+
save_steps=int(tr_cfg.get("save_steps", 200)),
|
| 791 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 792 |
+
eval_steps=int(tr_cfg.get("eval_steps", 50)),
|
| 793 |
+
load_best_model_at_end=(
|
| 794 |
+
bool(tr_cfg.get("load_best_model_at_end", True))
|
| 795 |
+
if eval_ds is not None
|
| 796 |
+
else False
|
| 797 |
+
),
|
| 798 |
+
metric_for_best_model="eval_loss",
|
| 799 |
+
greater_is_better=False,
|
| 800 |
+
fp16=use_fp16,
|
| 801 |
+
bf16=use_bf16,
|
| 802 |
+
report_to=report_to,
|
| 803 |
+
remove_unused_columns=False,
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
# Set the correct argument name for this transformers version
|
| 807 |
+
ta_kwargs[eval_key] = str(
|
| 808 |
+
tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
training_args = TrainingArguments(**ta_kwargs)
|
| 812 |
+
|
| 813 |
+
# Setup callbacks
|
| 814 |
+
callbacks = [JsonlLoggerCallback(run_dir)]
|
| 815 |
+
|
| 816 |
+
# Add early stopping callback if enabled
|
| 817 |
+
early_stopping_cfg = tr_cfg.get("early_stopping", {})
|
| 818 |
+
if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
|
| 819 |
+
early_stopping_callback = EarlyStoppingCallback(
|
| 820 |
+
early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
|
| 821 |
+
early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
|
| 822 |
+
)
|
| 823 |
+
callbacks.append(early_stopping_callback)
|
| 824 |
+
print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
|
| 825 |
+
f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
|
| 826 |
+
|
| 827 |
+
# DPO-specific parameters
|
| 828 |
+
beta = float(dpo_cfg.get("beta", 0.1))
|
| 829 |
+
label_smoothing = float(dpo_cfg.get("label_smoothing", 0.0))
|
| 830 |
+
loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
|
| 831 |
+
max_length = int(cfg["data"].get("max_length", 2048))
|
| 832 |
+
max_prompt_length = int(cfg["data"].get("max_prompt_length", max_length // 2))
|
| 833 |
+
|
| 834 |
+
print(f"DPO Training with beta={beta}, loss_type={loss_type}")
|
| 835 |
+
|
| 836 |
+
# Get evaluation strategy from config
|
| 837 |
+
eval_strategy_val = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
|
| 838 |
+
|
| 839 |
+
# Create DPOConfig with all training and DPO-specific parameters
|
| 840 |
+
dpo_config = DPOConfig(
|
| 841 |
+
output_dir=str(run_dir),
|
| 842 |
+
num_train_epochs=int(tr_cfg.get("num_train_epochs", 3)),
|
| 843 |
+
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
|
| 844 |
+
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
|
| 845 |
+
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
|
| 846 |
+
learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
|
| 847 |
+
weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
|
| 848 |
+
adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
|
| 849 |
+
adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
|
| 850 |
+
adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
|
| 851 |
+
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
|
| 852 |
+
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
|
| 853 |
+
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
|
| 854 |
+
logging_steps=int(tr_cfg.get("logging_steps", 10)),
|
| 855 |
+
save_steps=int(tr_cfg.get("save_steps", 100)),
|
| 856 |
+
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
|
| 857 |
+
eval_steps=int(tr_cfg.get("eval_steps", 100)) if eval_ds is not None else None,
|
| 858 |
+
eval_strategy=eval_strategy_val,
|
| 859 |
+
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
|
| 860 |
+
load_best_model_at_end=(
|
| 861 |
+
bool(tr_cfg.get("load_best_model_at_end", False))
|
| 862 |
+
if eval_ds is not None
|
| 863 |
+
else False
|
| 864 |
+
),
|
| 865 |
+
metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
|
| 866 |
+
greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
|
| 867 |
+
fp16=use_fp16,
|
| 868 |
+
bf16=use_bf16,
|
| 869 |
+
report_to=report_to,
|
| 870 |
+
remove_unused_columns=False,
|
| 871 |
+
# DPO-specific parameters
|
| 872 |
+
beta=beta,
|
| 873 |
+
label_smoothing=label_smoothing,
|
| 874 |
+
loss_type=loss_type,
|
| 875 |
+
max_length=max_length,
|
| 876 |
+
max_prompt_length=max_prompt_length,
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
trainer = DPOTrainer(
|
| 880 |
+
model=model,
|
| 881 |
+
ref_model=ref_model,
|
| 882 |
+
args=dpo_config,
|
| 883 |
+
train_dataset=train_ds,
|
| 884 |
+
eval_dataset=eval_ds,
|
| 885 |
+
processing_class=tokenizer,
|
| 886 |
+
callbacks=callbacks,
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
# Resume
|
| 890 |
+
resume_from = tr_cfg.get("resume_from_checkpoint", None)
|
| 891 |
+
if resume_from == "auto":
|
| 892 |
+
last = get_last_checkpoint(str(ckpt_dir))
|
| 893 |
+
resume_from = last if last else None
|
| 894 |
+
if resume_from:
|
| 895 |
+
print(f"Resuming from {resume_from}")
|
| 896 |
+
|
| 897 |
+
print("Starting DPO training...")
|
| 898 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 899 |
+
|
| 900 |
+
trainer.save_model(str(best_adapter_dir))
|
| 901 |
+
print(f"Saved best adapter -> {best_adapter_dir}")
|
| 902 |
+
|
| 903 |
+
if eval_ds is not None:
|
| 904 |
+
metrics = trainer.evaluate()
|
| 905 |
+
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
|
| 906 |
+
json.dump(metrics, f, indent=2)
|
| 907 |
+
print(f"Final metrics: {metrics}")
|
| 908 |
+
|
| 909 |
+
if bool(cfg.get("merge", {}).get("enabled", False)):
|
| 910 |
+
del trainer, model
|
| 911 |
+
if ref_model is not None:
|
| 912 |
+
del ref_model
|
| 913 |
+
torch.cuda.empty_cache()
|
| 914 |
+
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
|
| 915 |
+
else:
|
| 916 |
+
print("Merge disabled. Run with --merge-only later if needed.")
|
| 917 |
+
|
| 918 |
+
# Finish Wandb run
|
| 919 |
+
finish_wandb()
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
if __name__ == "__main__":
|
| 923 |
+
main()
|
DPO-14b/run_dpo_enhanced.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced DPO training script with improved error handling, validation, and memory management.
|
| 3 |
+
All critical fixes from the review have been implemented.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import gc
|
| 8 |
+
import json
|
| 9 |
+
import inspect
|
| 10 |
+
import math
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, Optional, Tuple, List
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import yaml
|
| 18 |
+
from datasets import load_dataset, DatasetDict
|
| 19 |
+
from huggingface_hub import snapshot_download
|
| 20 |
+
from transformers import (
|
| 21 |
+
AutoTokenizer,
|
| 22 |
+
AutoModelForCausalLM,
|
| 23 |
+
BitsAndBytesConfig,
|
| 24 |
+
TrainingArguments,
|
| 25 |
+
TrainerCallback,
|
| 26 |
+
EarlyStoppingCallback,
|
| 27 |
+
set_seed,
|
| 28 |
+
)
|
| 29 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 30 |
+
from peft import (
|
| 31 |
+
LoraConfig,
|
| 32 |
+
get_peft_model,
|
| 33 |
+
prepare_model_for_kbit_training,
|
| 34 |
+
PeftModel,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Version check for TRL
|
| 38 |
+
try:
|
| 39 |
+
import trl
|
| 40 |
+
from trl import DPOTrainer, DPOConfig
|
| 41 |
+
from packaging import version
|
| 42 |
+
if version.parse(trl.__version__) < version.parse("0.7.0"):
|
| 43 |
+
print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
|
| 44 |
+
except ImportError as e:
|
| 45 |
+
raise ImportError("TRL library not found. Install with: pip install trl>=0.7.0") from e
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
import wandb
|
| 49 |
+
WANDB_AVAILABLE = True
|
| 50 |
+
except ImportError:
|
| 51 |
+
WANDB_AVAILABLE = False
|
| 52 |
+
wandb = None
|
| 53 |
+
|
| 54 |
+
# Setup logging
|
| 55 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# --------------------------
|
| 60 |
+
# Custom Exceptions
|
| 61 |
+
# --------------------------
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DataFormattingError(Exception):
|
| 65 |
+
"""Exception raised for errors in data formatting."""
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DataValidationError(Exception):
|
| 70 |
+
"""Exception raised for errors in data validation."""
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# --------------------------
|
| 75 |
+
# SUMMARY OF FIXES IMPLEMENTED
|
| 76 |
+
# --------------------------
|
| 77 |
+
"""
|
| 78 |
+
β
CRITICAL FIXES:
|
| 79 |
+
1. Memory cleanup with gc.collect() and torch.cuda.empty_cache() in merge_adapter()
|
| 80 |
+
2. TRL version compatibility check (>= 0.7.0)
|
| 81 |
+
3. Error handling in data formatting with DataFormattingError
|
| 82 |
+
4. Data validation before training with validate_dpo_data()
|
| 83 |
+
|
| 84 |
+
β
HIGH PRIORITY FIXES:
|
| 85 |
+
5. Logging with proper logger setup
|
| 86 |
+
6. Error counting and reporting in data formatting
|
| 87 |
+
7. Gradient norm validation
|
| 88 |
+
8. Dataset filtering to remove failed examples
|
| 89 |
+
|
| 90 |
+
β
MEDIUM PRIORITY FIXES:
|
| 91 |
+
9. Progress descriptions in data processing
|
| 92 |
+
10. Validation of empty fields
|
| 93 |
+
11. Try-except blocks around critical sections
|
| 94 |
+
12. Better error messages with context
|
| 95 |
+
|
| 96 |
+
β
IMPROVEMENTS:
|
| 97 |
+
13. Type hints retained
|
| 98 |
+
14. Proper exception hierarchy
|
| 99 |
+
15. Logging instead of print statements
|
| 100 |
+
16. Memory-efficient merge process
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
print("=" * 80)
|
| 104 |
+
print("DPO TRAINER - ENHANCED VERSION")
|
| 105 |
+
print("=" * 80)
|
| 106 |
+
print("β
Memory management improvements")
|
| 107 |
+
print("β
Error handling and validation")
|
| 108 |
+
print("β
TRL version compatibility check")
|
| 109 |
+
print("β
Data quality checks")
|
| 110 |
+
print("=" * 80)
|
| 111 |
+
return fallback
|
| 112 |
+
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
|
| 119 |
+
return cfg.get("model", {}).get("attn_implementation", None)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# --------------------------
|
| 123 |
+
# Wandb Integration
|
| 124 |
+
# --------------------------
|
| 125 |
+
|
| 126 |
+
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
|
| 127 |
+
"""Initialize Wandb if enabled in configuration."""
|
| 128 |
+
wandb_cfg = cfg.get("wandb", {})
|
| 129 |
+
|
| 130 |
+
if not wandb_cfg.get("enabled", False):
|
| 131 |
+
print("Wandb logging disabled")
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
if not WANDB_AVAILABLE:
|
| 135 |
+
print("Wandb not available. Install with: pip install wandb")
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
project = wandb_cfg.get("project", "dpo-training")
|
| 139 |
+
entity = wandb_cfg.get("entity", None)
|
| 140 |
+
name = wandb_cfg.get("name", None)
|
| 141 |
+
tags = wandb_cfg.get("tags", [])
|
| 142 |
+
notes = wandb_cfg.get("notes", None)
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
wandb.init(
|
| 146 |
+
project=project,
|
| 147 |
+
entity=entity,
|
| 148 |
+
name=name,
|
| 149 |
+
tags=tags,
|
| 150 |
+
notes=notes,
|
| 151 |
+
dir=str(run_dir),
|
| 152 |
+
config={
|
| 153 |
+
"model": cfg.get("model", {}),
|
| 154 |
+
"data": cfg.get("data", {}),
|
| 155 |
+
"peft": cfg.get("peft", {}),
|
| 156 |
+
"dpo": cfg.get("dpo", {}),
|
| 157 |
+
"train": cfg.get("train", {}),
|
| 158 |
+
"run_dir": str(run_dir),
|
| 159 |
+
}
|
| 160 |
+
)
|
| 161 |
+
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
|
| 162 |
+
return wandb
|
| 163 |
+
except Exception as e:
|
| 164 |
+
print(f"Failed to initialize Wandb: {e}")
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def finish_wandb():
|
| 169 |
+
"""Finish Wandb run if active."""
|
| 170 |
+
if WANDB_AVAILABLE and wandb.run is not None:
|
| 171 |
+
wandb.finish()
|
| 172 |
+
print("Wandb run finished")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# --------------------------
|
| 176 |
+
# JSONL Logger Callback
|
| 177 |
+
# --------------------------
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class JsonlLoggerCallback(TrainerCallback):
|
| 181 |
+
def __init__(self, run_dir: Path):
|
| 182 |
+
self.run_dir = run_dir
|
| 183 |
+
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
|
| 184 |
+
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
|
| 185 |
+
self.start_time = None
|
| 186 |
+
|
| 187 |
+
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
|
| 188 |
+
if self.start_time is None or global_step <= 0 or max_steps <= 0:
|
| 189 |
+
return None
|
| 190 |
+
elapsed = time.time() - self.start_time
|
| 191 |
+
sec_per_step = elapsed / global_step
|
| 192 |
+
remaining = max(0, max_steps - global_step) * sec_per_step
|
| 193 |
+
h = int(remaining // 3600)
|
| 194 |
+
m = int((remaining % 3600) // 60)
|
| 195 |
+
s = int(remaining % 60)
|
| 196 |
+
return f"{h:02d}:{m:02d}:{s:02d}"
|
| 197 |
+
|
| 198 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 199 |
+
self.start_time = time.time()
|
| 200 |
+
|
| 201 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 202 |
+
if not logs:
|
| 203 |
+
return
|
| 204 |
+
|
| 205 |
+
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
|
| 206 |
+
progress_pct = (
|
| 207 |
+
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
|
| 208 |
+
)
|
| 209 |
+
epoch_pct = None
|
| 210 |
+
if (
|
| 211 |
+
state.epoch is not None
|
| 212 |
+
and args.num_train_epochs
|
| 213 |
+
and args.num_train_epochs > 0
|
| 214 |
+
):
|
| 215 |
+
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
|
| 216 |
+
|
| 217 |
+
payload = {
|
| 218 |
+
"ts": _now_iso(),
|
| 219 |
+
"event": "train_log",
|
| 220 |
+
"step": int(state.global_step),
|
| 221 |
+
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
|
| 222 |
+
"progress_pct": (
|
| 223 |
+
round(progress_pct, 2) if progress_pct is not None else None
|
| 224 |
+
),
|
| 225 |
+
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
|
| 226 |
+
"eta": self._eta(int(state.global_step), max_steps),
|
| 227 |
+
"max_grad_norm": getattr(args, "max_grad_norm", None),
|
| 228 |
+
**logs,
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
with self.train_log_path.open("a", encoding="utf-8") as f:
|
| 232 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 233 |
+
|
| 234 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 235 |
+
if not metrics:
|
| 236 |
+
return
|
| 237 |
+
eval_loss = metrics.get("eval_loss", None)
|
| 238 |
+
|
| 239 |
+
payload = {
|
| 240 |
+
"ts": _now_iso(),
|
| 241 |
+
"event": "eval",
|
| 242 |
+
"step": int(state.global_step),
|
| 243 |
+
"epoch": float(state.epoch) if state.epoch is not None else None,
|
| 244 |
+
**metrics,
|
| 245 |
+
}
|
| 246 |
+
with self.eval_log_path.open("a", encoding="utf-8") as f:
|
| 247 |
+
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# --------------------------
|
| 251 |
+
# Custom Exceptions
|
| 252 |
+
# --------------------------
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class DataFormattingError(Exception):
|
| 256 |
+
"""Exception raised for errors in data formatting."""
|
| 257 |
+
pass
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class DataValidationError(Exception):
|
| 261 |
+
"""Exception raised for errors in data validation."""
|
| 262 |
+
pass
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# --------------------------
|
| 266 |
+
# Data Pipeline (DPO Format)
|
| 267 |
+
# --------------------------
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def format_dpo_example(
|
| 271 |
+
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
|
| 272 |
+
) -> Dict[str, Any]:
|
| 273 |
+
"""
|
| 274 |
+
Format DPO data which requires prompt, chosen, and rejected completions.
|
| 275 |
+
Returns formatted prompt, chosen, and rejected texts.
|
| 276 |
+
Raises DataFormattingError if formatting fails.
|
| 277 |
+
"""
|
| 278 |
+
data_cfg = cfg["data"]
|
| 279 |
+
format_type = data_cfg.get("format_type", "chatml")
|
| 280 |
+
|
| 281 |
+
# Get field names from config
|
| 282 |
+
prompt_field = data_cfg.get("prompt_field", "prompt")
|
| 283 |
+
chosen_field = data_cfg.get("chosen_field", "chosen")
|
| 284 |
+
rejected_field = data_cfg.get("rejected_field", "rejected")
|
| 285 |
+
|
| 286 |
+
# Extract text from example
|
| 287 |
+
prompt = example.get(prompt_field, "")
|
| 288 |
+
chosen = example.get(chosen_field, "")
|
| 289 |
+
rejected = example.get(rejected_field, "")
|
| 290 |
+
|
| 291 |
+
# Validate required fields
|
| 292 |
+
if not prompt:
|
| 293 |
+
raise DataFormattingError(f"Empty prompt field: {prompt_field}")
|
| 294 |
+
if not chosen:
|
| 295 |
+
raise DataFormattingError(f"Empty chosen field: {chosen_field}")
|
| 296 |
+
if not rejected:
|
| 297 |
+
raise DataFormattingError(f"Empty rejected field: {rejected_field}")
|
| 298 |
+
|
| 299 |
+
if format_type == "chatml":
|
| 300 |
+
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
|
| 301 |
+
|
| 302 |
+
# Format prompt with system message
|
| 303 |
+
messages = []
|
| 304 |
+
if system_prompt:
|
| 305 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 306 |
+
messages.append({"role": "user", "content": prompt})
|
| 307 |
+
|
| 308 |
+
# Apply chat template for prompt only (without assistant response)
|
| 309 |
+
formatted_prompt = tokenizer.apply_chat_template(
|
| 310 |
+
messages, tokenize=False, add_generation_prompt=True
|
DPO-14b/test_fixes.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify all critical fixes have been applied to run_dpo.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
def check_fixes():
|
| 10 |
+
"""Check if all critical fixes are present in run_dpo.py"""
|
| 11 |
+
|
| 12 |
+
filepath = Path("run_dpo.py")
|
| 13 |
+
if not filepath.exists():
|
| 14 |
+
print(f"β Error: {filepath} not found")
|
| 15 |
+
return False
|
| 16 |
+
|
| 17 |
+
with open(filepath, 'r') as f:
|
| 18 |
+
content = f.read()
|
| 19 |
+
|
| 20 |
+
checks = []
|
| 21 |
+
|
| 22 |
+
# Check 1: Memory cleanup imports
|
| 23 |
+
if 'import gc' in content:
|
| 24 |
+
checks.append(("β
", "gc import added"))
|
| 25 |
+
else:
|
| 26 |
+
checks.append(("β", "gc import missing"))
|
| 27 |
+
|
| 28 |
+
# Check 2: Logging setup
|
| 29 |
+
if 'import logging' in content and 'logging.basicConfig' in content:
|
| 30 |
+
checks.append(("β
", "Logging setup configured"))
|
| 31 |
+
else:
|
| 32 |
+
checks.append(("β", "Logging setup missing"))
|
| 33 |
+
|
| 34 |
+
# Check 3: Custom exceptions
|
| 35 |
+
if 'class DataFormattingError' in content and 'class DataValidationError' in content:
|
| 36 |
+
checks.append(("β
", "Custom exceptions defined"))
|
| 37 |
+
else:
|
| 38 |
+
checks.append(("β", "Custom exceptions missing"))
|
| 39 |
+
|
| 40 |
+
# Check 4: Data validation function
|
| 41 |
+
if 'def validate_dpo_data' in content:
|
| 42 |
+
checks.append(("β
", "Data validation function defined"))
|
| 43 |
+
if 'validate_dpo_data(formatted_train' in content:
|
| 44 |
+
checks.append(("β
", "Data validation called for train"))
|
| 45 |
+
else:
|
| 46 |
+
checks.append(("β", "Data validation not called"))
|
| 47 |
+
else:
|
| 48 |
+
checks.append(("β", "Data validation function missing"))
|
| 49 |
+
|
| 50 |
+
# Check 5: Memory cleanup in merge_adapter
|
| 51 |
+
if 'del base' in content and 'gc.collect()' in content:
|
| 52 |
+
checks.append(("β
", "Memory cleanup in merge_adapter"))
|
| 53 |
+
else:
|
| 54 |
+
checks.append(("β", "Memory cleanup missing"))
|
| 55 |
+
|
| 56 |
+
# Check 6: TRL version check
|
| 57 |
+
if 'version.parse(trl.__version__)' in content:
|
| 58 |
+
checks.append(("β
", "TRL version check added"))
|
| 59 |
+
else:
|
| 60 |
+
checks.append(("β", "TRL version check missing"))
|
| 61 |
+
|
| 62 |
+
# Check 7: Error handling in format function
|
| 63 |
+
if 'except (DataFormattingError, Exception) as e:' in content:
|
| 64 |
+
checks.append(("β
", "Error handling in format function"))
|
| 65 |
+
else:
|
| 66 |
+
checks.append(("β", "Error handling missing"))
|
| 67 |
+
|
| 68 |
+
# Check 8: Logger usage (replaced some prints)
|
| 69 |
+
if 'logger.info' in content and 'logger.warning' in content:
|
| 70 |
+
checks.append(("β
", "Logger used for logging"))
|
| 71 |
+
else:
|
| 72 |
+
checks.append(("β", "Logger not properly used"))
|
| 73 |
+
|
| 74 |
+
# Check 9: Gradient norm validation (should be in TrainingArguments)
|
| 75 |
+
if 'max_grad_norm' in content:
|
| 76 |
+
checks.append(("β
", "Gradient norm parameter present"))
|
| 77 |
+
else:
|
| 78 |
+
checks.append(("β οΈ", "Gradient norm parameter not found"))
|
| 79 |
+
|
| 80 |
+
# Print results
|
| 81 |
+
print("="*80)
|
| 82 |
+
print("DPO TRAINER - FIX VERIFICATION")
|
| 83 |
+
print("="*80)
|
| 84 |
+
|
| 85 |
+
for status, message in checks:
|
| 86 |
+
print(f"{status} {message}")
|
| 87 |
+
|
| 88 |
+
print("="*80)
|
| 89 |
+
|
| 90 |
+
# Summary
|
| 91 |
+
passed = sum(1 for s, _ in checks if s == "β
")
|
| 92 |
+
failed = sum(1 for s, _ in checks if s == "β")
|
| 93 |
+
warnings = sum(1 for s, _ in checks if s == "β οΈ")
|
| 94 |
+
|
| 95 |
+
print(f"\nSummary: {passed} passed, {failed} failed, {warnings} warnings")
|
| 96 |
+
|
| 97 |
+
if failed == 0:
|
| 98 |
+
print("\nβ
All critical fixes verified successfully!")
|
| 99 |
+
print("\nYou can now proceed with training:")
|
| 100 |
+
print(" python run_dpo.py --config config_dpo.yaml")
|
| 101 |
+
return True
|
| 102 |
+
else:
|
| 103 |
+
print("\nβ Some fixes are missing. Please review the implementation.")
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
if __name__ == "__main__":
|
| 107 |
+
success = check_fixes()
|
| 108 |
+
sys.exit(0 if success else 1)
|