DPO Training for Code Analysis
This folder contains a Direct Preference Optimization (DPO) trainer for fine-tuning models on code analysis tasks using preference pairs.
Overview
DPO training uses preference pairs (chosen/rejected responses) to optimize the model to prefer better outputs over worse ones. This is particularly useful for tasks where we have multiple responses with different quality levels.
Files
- run_dpo.py - Main DPO training script
- config_dpo.yaml - Configuration file for DPO training
- f1_score_utils.py - Utilities for computing F1 scores and creating preference pairs
- requirements.txt - Python dependencies
- dpo_dataset.jsonl - Sample DPO dataset
Data Format
DPO requires data in the following format:
{
  "prompt": "##TASK\n<task description>",
  "chosen": "<better response with correct file selections>",
  "rejected": "<worse response with incorrect file selections>",
  "chosen_f1": 1.0,
  "rejected_f1": 0.5
}
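A JSONL file in this format can be loaded directly with the HuggingFace datasets library, which is handy for inspecting pairs before training (a sketch; run_dpo.py may load the data differently):

from datasets import load_dataset

# Each line of the JSONL file becomes one example with prompt/chosen/rejected fields
dataset = load_dataset("json", data_files="dpo_dataset.jsonl", split="train")
print(dataset[0]["prompt"][:80], dataset[0]["chosen_f1"])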
Creating DPO Data from SFT Data
You can use the F1 score utility to create DPO pairs from multiple model generations:
from f1_score_utils import create_dpo_pairs_from_generations
prompt = "##TASK\nAdd webhook support..."
generations = [output1, output2, output3, output4] # Multiple model outputs
ground_truth = "##OUTPUT\n...\n##SELECT\n..."
pairs = create_dpo_pairs_from_generations(
    prompt, generations, ground_truth, min_f1_difference=0.1
)
F1 Score Ranking
The F1 score is computed at the file level:
- Precision: Correct files / Total predicted files
- Recall: Correct files / Total ground truth files
- F1: Harmonic mean of precision and recall
Files are extracted from the ##SELECT section:
##SELECT
crates/router/src/webhooks.rs::process_webhook
crates/common_enums/src/enums.rs::EventClass
<EOS>
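For clarity, here is a minimal sketch of the file-level F1 computation described above. The function names are illustrative, not necessarily those in f1_score_utils.py:

def extract_select_files(output):
    # Collect the file entries between ##SELECT and <EOS>
    files, in_select = set(), False
    for line in output.splitlines():
        line = line.strip()
        if line == "##SELECT":
            in_select = True
        elif line == "<EOS>":
            in_select = False
        elif in_select and line:
            files.add(line)
    return files

def file_level_f1(predicted, ground_truth):
    # Precision = correct / predicted, Recall = correct / ground truth,
    # F1 = harmonic mean of the two
    tp = len(predicted & ground_truth)
    precision = tp / len(predicted) if predicted else 0.0
    recall = tp / len(ground_truth) if ground_truth else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)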
Installation
pip install -r requirements.txt
Usage
1. Prepare DPO Dataset
You need to generate multiple outputs for each prompt and rank them by F1 score:
from f1_score_utils import compute_file_level_f1, rank_outputs_by_f1

# Rank candidate outputs (a list of generated strings) against the ground truth
ranked = rank_outputs_by_f1(outputs, ground_truth)
for output, f1, metrics in ranked:
    print(f"F1: {f1:.3f} - {metrics['true_positives']} correct files")
2. Configure Training
Edit config_dpo.yaml:
- Set model.repo_id to your SFT model path
- Adjust dpo.beta (temperature parameter, default 0.1)
- Set dpo.loss_type (sigmoid, hinge, ipo, kto)
- Configure training hyperparameters (see the example excerpt below)
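A hypothetical excerpt of config_dpo.yaml is sketched below. Only model.repo_id and the dpo.* keys are documented in this README; the section layout and remaining names are assumptions for illustration:

model:
  repo_id: "runs/sft_run_14b_v1/merged_model"  # assumption: path to your SFT model

dpo:
  beta: 0.1              # documented default
  loss_type: "sigmoid"   # sigmoid | hinge | ipo | kto
  label_smoothing: 0.0
  use_reference_model: true

training:                # assumption: hyperparameter section name
  learning_rate: 5.0e-5  # lower than typical SFT (see Training Tips)
  warmup_ratio: 0.1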
3. Run Training
python run_dpo.py --config config_dpo.yaml
4. Merge Adapter (Optional)
If training is complete and you want to merge the adapter:
python run_dpo.py --config config_dpo.yaml --merge-only
Configuration
DPO Parameters
- beta: Temperature for DPO loss (higher = less aggressive preference learning)
- label_smoothing: Smoothing factor for labels
- loss_type: Type of loss function
  - sigmoid: Standard DPO loss (default; sketched below)
  - hinge: Margin-based loss
  - ipo: Identity Preference Optimization
  - kto: Kahneman-Tversky Optimization
- use_reference_model: Whether to use a frozen reference model
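For intuition, the standard sigmoid DPO loss can be sketched as follows. This is a minimal illustration of the published formulation, not the trainer's actual code; the arguments are per-example log-probability tensors:

import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: beta-scaled log-ratios of the policy vs. the frozen reference
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # -log sigmoid of the reward margin; a higher beta keeps the policy
    # closer to the reference, i.e. less aggressive preference learning
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()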
Training Tips
- Learning Rate: Use a lower LR than for SFT (e.g., 5e-5 vs. 2e-4)
- Beta: Start with 0.1; increase for less aggressive preference learning
- Batch Size: Larger batches are more stable
- Data Quality: Ensure a significant F1 difference between chosen and rejected responses (≥0.1)
Output
Training outputs:
- runs/dpo_run_14b_v1/checkpoints/ - Training checkpoints
- runs/dpo_run_14b_v1/best_adapter/ - Best adapter weights
- runs/dpo_run_14b_v1/merged_14b_dpo_lora/ - Merged model
- runs/dpo_run_14b_v1/logs/ - Training logs (JSONL format)
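Since the logs are JSONL (one record per line), they can be inspected with a few lines of Python. The file name and record keys below are assumptions for illustration:

import json

with open("runs/dpo_run_14b_v1/logs/train.jsonl") as f:  # hypothetical file name
    for line in f:
        record = json.loads(line)
        print(record.get("step"), record.get("loss"))  # hypothetical keys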
WandB Integration
Enable experiment tracking in config_dpo.yaml:
wandb:
  enabled: true
  project: "dpo-training"
  tags: ["dpo-lora", "preference-optimization"]
Example: Generate DPO Data
import json
from f1_score_utils import compute_file_level_f1, create_dpo_pairs_from_generations

# Load SFT data
with open("instruct_data.jsonl") as f:
    for line in f:
        data = json.loads(line)
        prompt = data["input"]
        ground_truth = data["output"]

        # Generate multiple outputs with your model
        generations = generate_multiple_outputs(prompt, num_samples=4)

        # Create preference pairs
        pairs = create_dpo_pairs_from_generations(
            prompt, generations, ground_truth, min_f1_difference=0.1
        )

        # Save pairs
        with open("dpo_dataset.jsonl", "a") as out:
            for pair in pairs:
                out.write(json.dumps(pair) + "\n")
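The generate_multiple_outputs helper above is a placeholder you must supply. One possible implementation with HuggingFace transformers, assuming a causal LM model and its tokenizer are already loaded, is sketched below; sampling with a nonzero temperature is what produces quality differences between generations:

def generate_multiple_outputs(prompt, num_samples=4):
    # Sample several diverse completions for one prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.8,
        num_return_sequences=num_samples,
        max_new_tokens=1024,
    )
    # Strip the prompt tokens and decode only the generated continuation
    prompt_len = inputs["input_ids"].shape[1]
    return [
        tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
        for seq in outputs
    ]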
Troubleshooting
- OOM Errors: Reduce batch size or enable gradient checkpointing
- No Improvement: Check F1 score differences in data, increase beta
- Unstable Training: Lower learning rate, increase warmup ratio
- Reference Model Issues: Set use_reference_model: false to use the implicit reference
References
- DPO Paper: Direct Preference Optimization: Your Language Model is Secretly a Reward Model (Rafailov et al., 2023)
- TRL Library: HuggingFace TRL (https://github.com/huggingface/trl)