| """ |
| Quick script to convert SFT data to DPO format for training. |
| Since we don't have multiple model generations, we'll create synthetic pairs |
| by using the ground truth as "chosen" and creating degraded versions as "rejected". |
| """ |
|
|
| import json |
| import random |
| from pathlib import Path |
| import sys |
|
|
| sys.path.append(str(Path(__file__).parent)) |
| from f1_score_utils import compute_file_level_f1 |
|
|
|
|
def degrade_output(output: str) -> str:
    """
    Create a degraded ("rejected") version of a model output.

    The text before the first ``##SELECT`` marker (the explanation) is kept
    verbatim. The selection section that follows it (up to ``<EOS>`` if
    present) is split into non-empty, stripped lines, and exactly one of
    these strategies is applied at random:

    1. ``remove``  - drop one or two selections (at least one always remains)
    2. ``add``     - append one incorrect selection
    3. ``replace`` - overwrite one selection with an incorrect one

    Args:
        output: Full output text, expected to contain a ``##SELECT`` section.

    Returns:
        ``"<explanation>##SELECT\\n<selections>\\n<EOS>"`` with the modified
        selections. The input is returned unchanged when it has no
        ``##SELECT`` marker or at most one selection line (nothing can be
        degraded safely). Anything after ``<EOS>`` or after a second
        ``##SELECT`` marker in the input is not carried over.
    """
    if "##SELECT" not in output:
        return output

    parts = output.split("##SELECT")
    explanation = parts[0]
    # split() returns the whole string as element 0 when "<EOS>" is absent.
    select_section = parts[1].split("<EOS>")[0]

    lines = [ln.strip() for ln in select_section.strip().split("\n") if ln.strip()]
    if len(lines) <= 1:
        return output

    fake_files = [
        "crates/router/src/handlers/utils.rs::helper_function",
        "crates/api_models/src/types.rs::RequestType",
        "crates/common_utils/src/helpers.rs::parse_data",
        "crates/diesel_models/src/schema.rs::table_definition",
    ]
    # Only use fake entries that are not already selected; otherwise "add"
    # would produce a duplicate line and "replace" could be a no-op.
    candidates = [fake for fake in fake_files if fake not in lines]

    if candidates:
        strategy = random.choice(["remove", "add", "replace"])
    else:
        # Every fake entry is already present, so removal is the only
        # strategy that is guaranteed to change the selection.
        strategy = "remove"

    if strategy == "remove":
        num_to_remove = min(random.randint(1, 2), len(lines) - 1)
        # Sample *indices* so the surviving lines keep their original order;
        # sampling the lines themselves would also shuffle them.
        dropped = set(random.sample(range(len(lines)), num_to_remove))
        new_lines = [ln for idx, ln in enumerate(lines) if idx not in dropped]
    elif strategy == "add":
        new_lines = lines + [random.choice(candidates)]
    else:
        new_lines = list(lines)
        new_lines[random.randrange(len(lines))] = random.choice(candidates)

    new_select = "\n".join(new_lines)
    return f"{explanation}##SELECT\n{new_select}\n<EOS>"
|
|
|
|
def create_dpo_pairs(input_jsonl: str, output_jsonl: str, max_examples: int = None):
    """
    Convert SFT data to DPO format by creating synthetic degraded versions.

    Reads ``input_jsonl`` line by line. A record is usable when it is a JSON
    object with a non-empty ``"input"`` and a non-empty string ``"output"``
    that contains a ``##SELECT`` section. Each usable record is degraded two
    or three times with ``degrade_output``; every distinct degraded variant
    whose file-level F1 is at least 0.1 below the ground truth's is written
    to ``output_jsonl`` as one JSON line with the keys ``prompt``, ``chosen``,
    ``rejected``, ``chosen_f1`` and ``rejected_f1``.

    Progress is printed every 100 processed examples and a summary is printed
    at the end.

    Args:
        input_jsonl: Path of the SFT JSONL file to read.
        output_jsonl: Path of the DPO JSONL file to write (overwritten).
        max_examples: Stop after this many usable examples. ``None`` (or 0)
            means no limit.
    """
    pairs_created = 0
    examples_processed = 0

    # Explicit UTF-8 so results do not depend on the platform default encoding.
    with open(input_jsonl, "r", encoding="utf-8") as f_in, \
            open(output_jsonl, "w", encoding="utf-8") as f_out:
        for line in f_in:
            if max_examples and examples_processed >= max_examples:
                break

            # Skip blank/malformed lines, but let KeyboardInterrupt and other
            # unrelated errors propagate instead of swallowing them.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue

            prompt = data.get("input", "")
            ground_truth = data.get("output", "")

            if (
                not prompt
                or not isinstance(ground_truth, str)
                or "##SELECT" not in ground_truth
            ):
                continue

            # The reference score does not change between degraded variants,
            # so compute it once per example
            # (assumes compute_file_level_f1 is deterministic - TODO confirm).
            gt_metrics = compute_file_level_f1(ground_truth, ground_truth)

            # Different random draws can yield the same degraded text; emit
            # each distinct rejected sample only once per example.
            seen_rejected = set()

            num_degraded = random.randint(2, 3)
            for _ in range(num_degraded):
                degraded = degrade_output(ground_truth)
                if degraded in seen_rejected:
                    continue
                seen_rejected.add(degraded)

                deg_metrics = compute_file_level_f1(degraded, ground_truth)

                # Keep only pairs with a meaningful quality gap.
                if gt_metrics["f1"] - deg_metrics["f1"] >= 0.1:
                    pair = {
                        "prompt": prompt,
                        "chosen": ground_truth,
                        "rejected": degraded,
                        "chosen_f1": gt_metrics["f1"],
                        "rejected_f1": deg_metrics["f1"],
                    }
                    f_out.write(json.dumps(pair) + "\n")
                    pairs_created += 1

            examples_processed += 1
            if examples_processed % 100 == 0:
                print(f"Processed {examples_processed} examples, created {pairs_created} pairs")

    print(f"\nDone! Processed {examples_processed} examples")
    print(f"Created {pairs_created} DPO pairs")
    print(f"Average pairs per example: {pairs_created / max(examples_processed, 1):.2f}")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: read the input/output paths and the optional
    # example cap, then run the conversion.
    cli = argparse.ArgumentParser()
    option_table = (
        ("--input", {"default": "../../sft_output.jsonl"}),
        ("--output", {"default": "dpo_pairs_generated.jsonl"}),
        ("--max-examples", {"type": int, "default": None}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()

    source_path = options.input
    target_path = options.output

    print(f"Converting {source_path} to DPO format...")
    create_dpo_pairs(source_path, target_path, options.max_examples)
    print(f"Output saved to: {target_path}")
|
|