# task2file-llm / trainer-kit / DPO / create_synthetic_pairs.py
# SirajRLX's picture
# Upload folder using huggingface_hub
# 4eae728 verified
"""
Quick script to convert SFT data to DPO format for training.
Since we don't have multiple model generations, we'll create synthetic pairs
by using the ground truth as "chosen" and creating degraded versions as "rejected".
"""
import json
import random
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent))
from f1_score_utils import compute_file_level_f1
def degrade_output(output: str) -> str:
    """
    Create a degraded ("rejected") variant of a ground-truth output.

    The text is split at the ``##SELECT`` marker. Everything before the marker
    (the explanation) is kept verbatim; the selection lines between the marker
    and ``<EOS>`` are altered with one randomly chosen strategy:

    * ``remove``  - drop 1-2 selections (at least one always survives)
    * ``add``     - append one fabricated selection
    * ``replace`` - overwrite one selection with a fabricated one

    Surviving selections keep their original relative order, so the only
    difference from the input is the intended degradation.

    Args:
        output: Model output text, expected to contain a ``##SELECT`` section
            with one selection per line, optionally terminated by ``<EOS>``.

    Returns:
        The degraded text, rebuilt as ``<explanation>##SELECT\\n<lines>\\n<EOS>``.
        ``output`` is returned unchanged when it has no ``##SELECT`` marker or
        lists at most one selection.
    """
    if "##SELECT" not in output:
        return output

    parts = output.split("##SELECT")
    explanation = parts[0]
    # str.split returns the whole string when "<EOS>" is absent, so no
    # separate membership check is needed. Anything after "<EOS>" is dropped.
    select_section = parts[1].split("<EOS>")[0]

    # One selection per non-blank line.
    lines = [entry.strip() for entry in select_section.strip().split('\n') if entry.strip()]
    if len(lines) <= 1:
        return output  # Can't degrade further

    fake_files = (
        "crates/router/src/handlers/utils.rs::helper_function",
        "crates/api_models/src/types.rs::RequestType",
        "crates/common_utils/src/helpers.rs::parse_data",
        "crates/diesel_models/src/schema.rs::table_definition",
    )
    # Prefer fabricated entries that are not already selected; otherwise
    # "add"/"replace" could produce an output equivalent to the input.
    fake_candidates = [name for name in fake_files if name not in lines] or list(fake_files)

    strategy = random.choice(['remove', 'add', 'replace'])

    if strategy == 'remove':
        # Never remove everything: keep at least one selection.
        num_to_remove = min(random.randint(1, 2), len(lines) - 1)
        # Sample positions (not values) so survivors stay in original order.
        dropped = set(random.sample(range(len(lines)), num_to_remove))
        new_lines = [entry for pos, entry in enumerate(lines) if pos not in dropped]
    elif strategy == 'add':
        new_lines = lines + [random.choice(fake_candidates)]
    else:  # replace
        new_lines = list(lines)
        new_lines[random.randrange(len(lines))] = random.choice(fake_candidates)

    # Reconstruct output
    new_select = "\n".join(new_lines)
    return f"{explanation}##SELECT\n{new_select}\n<EOS>"
def create_dpo_pairs(input_jsonl: str, output_jsonl: str, max_examples: int = None):
    """
    Convert SFT data to DPO format by creating synthetic degraded versions.

    Each input line is parsed as a JSON object with an ``"input"`` (prompt)
    and an ``"output"`` (ground truth) field. For every usable record, 2-3
    degraded variants are generated with ``degrade_output``; a variant is
    written as a DPO pair only when its F1 against the ground truth is at
    least 0.1 lower than the ground truth's own F1.

    Lines that are not valid JSON, are not JSON objects, or whose output is
    missing, empty, not a string, or lacks a ``##SELECT`` section are skipped
    silently and do not count as processed examples.

    Args:
        input_jsonl: Path of the SFT JSONL file to read.
        output_jsonl: Path of the DPO JSONL file to write (overwritten).
        max_examples: Stop after this many usable examples. ``None`` (or 0)
            means no limit.

    Returns:
        None. Pairs are written to ``output_jsonl`` with the keys ``prompt``,
        ``chosen``, ``rejected``, ``chosen_f1`` and ``rejected_f1``; progress
        and a final summary are printed to stdout.
    """
    pairs_created = 0
    examples_processed = 0

    # Explicit UTF-8 so reading does not depend on the platform's locale.
    with open(input_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            if max_examples and examples_processed >= max_examples:
                break

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Malformed or blank line: skip it, but let KeyboardInterrupt
                # and other unrelated errors propagate.
                continue

            # A valid JSON line may still be a list/number/string.
            if not isinstance(data, dict):
                continue

            prompt = data.get("input", "")
            ground_truth = data.get("output", "")
            if not isinstance(ground_truth, str):
                continue
            if not prompt or not ground_truth or "##SELECT" not in ground_truth:
                continue

            # The reference score does not depend on the degraded variant,
            # so compute it once per example rather than once per variant.
            # NOTE(review): compute_file_level_f1 is assumed to take
            # (prediction, reference) and return a mapping with an "f1" key.
            gt_metrics = compute_file_level_f1(ground_truth, ground_truth)

            # Create 2-3 degraded versions
            num_degraded = random.randint(2, 3)
            for _ in range(num_degraded):
                degraded = degrade_output(ground_truth)
                deg_metrics = compute_file_level_f1(degraded, ground_truth)

                # Only create pair if there's a significant difference
                if gt_metrics["f1"] - deg_metrics["f1"] >= 0.1:
                    pair = {
                        "prompt": prompt,
                        "chosen": ground_truth,
                        "rejected": degraded,
                        "chosen_f1": gt_metrics["f1"],
                        "rejected_f1": deg_metrics["f1"],
                    }
                    f_out.write(json.dumps(pair) + '\n')
                    pairs_created += 1

            examples_processed += 1
            if examples_processed % 100 == 0:
                print(f"Processed {examples_processed} examples, created {pairs_created} pairs")

    print(f"\nDone! Processed {examples_processed} examples")
    print(f"Created {pairs_created} DPO pairs")
    print(f"Average pairs per example: {pairs_created / max(examples_processed, 1):.2f}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option is declared as (flag, settings)
    # and registered in order, then the conversion is run with the parsed
    # values.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--input", {"default": "../../sft_output.jsonl"}),
        ("--output", {"default": "dpo_pairs_generated.jsonl"}),
        ("--max-examples", {"type": int, "default": None}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    source_path = opts.input
    target_path = opts.output

    print(f"Converting {source_path} to DPO format...")
    create_dpo_pairs(source_path, target_path, opts.max_examples)
    print(f"Output saved to: {target_path}")