# DPO Training for Code Analysis

This folder contains a Direct Preference Optimization (DPO) trainer for fine-tuning models on code analysis tasks with preference pairs.

## Overview

DPO training uses preference pairs (chosen/rejected responses) to optimize the model to prefer better outputs over worse ones. This is particularly useful when multiple candidate responses of varying quality are available for the same prompt.

## Files

- `run_dpo.py` - Main DPO training script
- `config_dpo.yaml` - Configuration file for DPO training
- `f1_score_utils.py` - Utilities for computing F1 scores and creating preference pairs
- `requirements.txt` - Python dependencies
- `dpo_dataset.jsonl` - Sample DPO dataset

## Data Format

DPO requires data in JSON Lines format, one record per line. Each record has the following fields (shown pretty-printed here for readability):

```json
{
  "prompt": "##TASK\n<task description>",
  "chosen": "<better response with correct file selections>",
  "rejected": "<worse response with incorrect file selections>",
  "chosen_f1": 1.0,
  "rejected_f1": 0.5
}
```
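As a sanity check before training, you can validate that every record carries the fields above. This is a minimal sketch (the field names come from the format shown; `load_dpo_dataset` is an illustrative helper, not part of `f1_score_utils.py`):

```python
import json

REQUIRED_FIELDS = {"prompt", "chosen", "rejected", "chosen_f1", "rejected_f1"}

def load_dpo_dataset(path):
    """Load a DPO JSONL file, skipping records that lack required fields."""
    records = []
    with open(path) as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            missing = REQUIRED_FIELDS - record.keys()
            if missing:
                print(f"line {line_no}: missing fields {sorted(missing)}, skipping")
                continue
            records.append(record)
    return records
```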

### Creating DPO Data from SFT Data

You can use the F1 score utility to create DPO pairs from multiple model generations:

```python
from f1_score_utils import create_dpo_pairs_from_generations

prompt = "##TASK\nAdd webhook support..."
generations = [output1, output2, output3, output4]  # Multiple model outputs
ground_truth = "##OUTPUT\n...\n##SELECT\n..."

pairs = create_dpo_pairs_from_generations(
    prompt, generations, ground_truth, min_f1_difference=0.1
)
```

## F1 Score Ranking

The F1 score is computed at the **file level**:
- **Precision**: Correct files / Total predicted files
- **Recall**: Correct files / Total ground truth files
- **F1**: Harmonic mean of precision and recall

Files are extracted from the `##SELECT` section:
```
##SELECT
crates/router/src/webhooks.rs::process_webhook
crates/common_enums/src/enums.rs::EventClass
<EOS>
```
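The utilities in `f1_score_utils.py` implement this metric; as a self-contained sketch of the same computation (the `##SELECT`/`<EOS>` parsing below is an assumption based on the format shown above):

```python
def parse_select_files(output: str) -> set:
    """Extract file entries from the ##SELECT section, stopping at <EOS>."""
    files = set()
    in_select = False
    for line in output.splitlines():
        line = line.strip()
        if line == "##SELECT":
            in_select = True
            continue
        if in_select:
            if line == "<EOS>" or line.startswith("##"):
                break
            if line:
                files.add(line)
    return files

def file_level_f1(predicted: str, ground_truth: str) -> float:
    """F1 over the sets of selected files: harmonic mean of precision and recall."""
    pred = parse_select_files(predicted)
    truth = parse_select_files(ground_truth)
    if not pred or not truth:
        return 0.0
    tp = len(pred & truth)                 # correctly predicted files
    precision = tp / len(pred)             # correct / total predicted
    recall = tp / len(truth)               # correct / total ground truth
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```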

## Installation

```bash
pip install -r requirements.txt
```

## Usage

### 1. Prepare DPO Dataset

You need to generate multiple outputs for each prompt and rank them by F1 score:

```python
from f1_score_utils import rank_outputs_by_f1

# Rank outputs
ranked = rank_outputs_by_f1(outputs, ground_truth)
for output, f1, metrics in ranked:
    print(f"F1: {f1:.3f} - {metrics['true_positives']} correct files")
```

### 2. Configure Training

Edit `config_dpo.yaml`:
- Set `model.repo_id` to your SFT model path
- Adjust `dpo.beta` (temperature parameter, default 0.1)
- Set `dpo.loss_type` (sigmoid, hinge, ipo, kto)
- Configure training hyperparameters
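Putting those settings together, a minimal `config_dpo.yaml` might look like the following. The exact key layout depends on what `run_dpo.py` expects; only `model.repo_id`, the `dpo.*` parameters, and the `wandb` section are named in this README, and the remaining keys are illustrative:

```yaml
model:
  repo_id: runs/sft_run_14b_v1/merged   # path to your SFT model (illustrative)

dpo:
  beta: 0.1              # temperature; start at 0.1
  loss_type: sigmoid     # sigmoid | hinge | ipo | kto
  label_smoothing: 0.0
  use_reference_model: true

training:                # illustrative hyperparameter names
  learning_rate: 5.0e-5  # lower than SFT
  batch_size: 4
  warmup_ratio: 0.1
```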

### 3. Run Training

```bash
python run_dpo.py --config config_dpo.yaml
```

### 4. Merge Adapter (Optional)

Once training has finished, you can merge the adapter into the base model:

```bash
python run_dpo.py --config config_dpo.yaml --merge-only
```

## Configuration

### DPO Parameters

- `beta`: Temperature for the DPO loss (higher values keep the policy closer to the reference model, i.e. less aggressive preference learning)
- `label_smoothing`: Smoothing factor for labels
- `loss_type`: Type of loss function
  - `sigmoid`: Standard DPO loss (default)
  - `hinge`: Margin-based loss
  - `ipo`: Identity Policy Optimization
  - `kto`: Kahneman-Tversky Optimization
- `use_reference_model`: Whether to use a frozen reference model
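For intuition, the standard (sigmoid) DPO loss scores each response by how much the policy's log-probability exceeds the reference model's, scaled by `beta`, and penalizes the pair when the rejected response outscores the chosen one. A minimal pure-Python sketch of that formula (not the trainer's actual implementation):

```python
import math

def dpo_sigmoid_loss(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one preference pair.

    Each response's implicit reward is beta * (policy logprob - reference
    logprob); the loss is -log sigmoid(reward_chosen - reward_rejected).
    """
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) == log(1 + exp(-margin)); guard against overflow
    return math.log1p(math.exp(-margin)) if margin > -30 else -margin
```

When the policy prefers the chosen response more strongly than the reference does, the margin is positive and the loss approaches zero; a tied pair gives exactly log 2.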

### Training Tips

1. **Learning Rate**: Use lower LR than SFT (e.g., 5e-5 vs 2e-4)
2. **Beta**: Start with 0.1, increase for less aggressive learning
3. **Batch Size**: Larger batches are more stable
4. **Data Quality**: Ensure significant F1 difference between chosen/rejected (≥0.1)

## Output

Training outputs:
- `runs/dpo_run_14b_v1/checkpoints/` - Training checkpoints
- `runs/dpo_run_14b_v1/best_adapter/` - Best adapter weights
- `runs/dpo_run_14b_v1/merged_14b_dpo_lora/` - Merged model
- `runs/dpo_run_14b_v1/logs/` - Training logs (JSONL format)

## WandB Integration

Enable experiment tracking in `config_dpo.yaml`:

```yaml
wandb:
  enabled: true
  project: "dpo-training"
  tags: ["dpo-lora", "preference-optimization"]
```

## Example: Generate DPO Data

```python
import json
from f1_score_utils import create_dpo_pairs_from_generations

# Stream SFT records in and write preference pairs out as they are created.
# "w" overwrites any previous dataset; use "a" to append across runs.
with open("instruct_data.jsonl") as f, open("dpo_dataset.jsonl", "w") as out:
    for line in f:
        data = json.loads(line)
        prompt = data["input"]
        ground_truth = data["output"]

        # Sample multiple outputs with your model
        # (generate_multiple_outputs is a placeholder for your sampling code)
        generations = generate_multiple_outputs(prompt, num_samples=4)

        # Create preference pairs ranked by file-level F1
        pairs = create_dpo_pairs_from_generations(
            prompt, generations, ground_truth, min_f1_difference=0.1
        )

        # Save pairs, one JSON object per line
        for pair in pairs:
            out.write(json.dumps(pair) + "\n")
```

## Troubleshooting

1. **OOM Errors**: Reduce batch size or enable gradient checkpointing
2. **No Improvement**: Check F1 score differences in data, increase beta
3. **Unstable Training**: Lower learning rate, increase warmup ratio
4. **Reference Model Issues**: Set `use_reference_model: false` to use implicit reference

## References

- DPO Paper: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
- TRL Library: [HuggingFace TRL](https://github.com/huggingface/trl)