# Source: task2file-llm / trainer-kit / DPO / prepare_data.py
# Uploaded by SirajRLX via huggingface_hub (commit 4eae728, verified).
"""
Data preparation utilities for converting SFT data to DPO/GRPO formats.
This script helps generate multiple outputs and create preference/ranking datasets.
"""
import json
import argparse
from pathlib import Path
from typing import List, Dict
from f1_score_utils import (
compute_file_level_f1,
rank_outputs_by_f1,
create_dpo_pairs_from_generations
)
def load_model_for_generation(model_path: str):
    """
    Load a causal LM and its tokenizer for sampling completions.

    Placeholder loader - adapt it to your own setup. Weights are loaded in
    bfloat16 and placed with ``device_map="auto"``.

    Args:
        model_path: Local directory or hub id accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Heavy imports are deferred so the script can run without a model.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_path}...")
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    tok = AutoTokenizer.from_pretrained(model_path)
    lm = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return lm, tok
def generate_multiple_outputs(
    model,
    tokenizer,
    prompt: str,
    num_samples: int = 4,
    temperatures: List[float] = None,
    max_new_tokens: int = 512
) -> List[str]:
    """
    Generate several sampled completions for one prompt, one per temperature.

    Args:
        model: Causal LM exposing ``generate(...)`` and ``device``
            (e.g. a Hugging Face ``AutoModelForCausalLM``).
        tokenizer: Matching tokenizer; called with ``return_tensors="pt"``
            and used to decode the generated ids.
        prompt: Prompt text.
        num_samples: Number of completions when ``temperatures`` is not
            given. The default schedule [0.6, 0.8, 1.0, 1.2] is cycled so
            that exactly ``num_samples`` temperatures are used.
        temperatures: Explicit sampling temperatures. One completion is
            generated per entry, and ``num_samples`` is then ignored.
        max_new_tokens: Maximum number of new tokens per completion.

    Returns:
        The decoded completions (prompt tokens stripped), in temperature order.
    """
    if temperatures is None:
        default_schedule = [0.6, 0.8, 1.0, 1.2]
        # Cycle the schedule: slicing alone silently capped the result at
        # 4 outputs whenever num_samples > 4.
        temperatures = [
            default_schedule[i % len(default_schedule)] for i in range(num_samples)
        ]

    # 0 is a valid pad id, so test for None rather than relying on truthiness.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # The prompt encoding is the same for every sample, so tokenize it once.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs.input_ids.shape[1]

    outputs = []
    for temp in temperatures:
        generated = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temp,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        outputs.append(
            tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
        )
    return outputs
def convert_sft_to_dpo(
    sft_jsonl: str,
    output_jsonl: str,
    model_path: str = None,
    num_samples: int = 4,
    min_f1_difference: float = 0.1,
    max_examples: int = None
):
    """
    Convert an SFT dataset to DPO format by ranking several outputs per prompt.

    Each SFT record must carry an ``"input"`` prompt and an ``"output"``
    reference. Candidates are either sampled from the model at
    ``model_path`` or, when no model is given, read from the record's
    ``"outputs"`` list. ``create_dpo_pairs_from_generations`` turns them
    into preference pairs, which are written one per line.

    Args:
        sft_jsonl: Path to the SFT JSONL file (blank lines are ignored).
        output_jsonl: Path to the output DPO JSONL file (overwritten).
        model_path: Path to a model for generation. If None, pre-generated
            outputs are expected in each record's "outputs" field.
        num_samples: Number of outputs to generate per prompt.
        min_f1_difference: Minimum F1 difference for a pair to be created.
        max_examples: Maximum number of examples to process (None = all).
    """
    if model_path:
        model, tokenizer = load_model_for_generation(model_path)
    else:
        print("Warning: No model path provided. Expecting pre-generated outputs in data.")
        model, tokenizer = None, None

    pairs_created = 0
    examples_processed = 0
    with open(sft_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            # `is not None` so that max_examples=0 means "none", not "all".
            if max_examples is not None and examples_processed >= max_examples:
                break
            if not line.strip():
                continue  # tolerate blank lines instead of crashing json.loads
            data = json.loads(line)
            prompt = data.get("input", "")
            ground_truth = data.get("output", "")
            if not prompt or not ground_truth:
                continue

            # Obtain candidate outputs: sample them, or read pre-generated ones.
            if model is not None and tokenizer is not None:
                try:
                    generations = generate_multiple_outputs(
                        model, tokenizer, prompt, num_samples
                    )
                except Exception as e:
                    # Best effort: one failing prompt must not abort the run.
                    print(f"Error generating outputs: {e}")
                    continue
            else:
                generations = data.get("outputs", [])

            # A preference pair needs at least two candidates, whatever their source.
            if len(generations) < 2:
                print("Skipping example: need at least 2 outputs")
                continue

            pairs = create_dpo_pairs_from_generations(
                prompt, generations, ground_truth, min_f1_difference
            )
            for pair in pairs:
                f_out.write(json.dumps(pair) + '\n')
                pairs_created += 1

            examples_processed += 1
            if examples_processed % 10 == 0:
                print(f"Processed {examples_processed} examples, created {pairs_created} pairs")

    print(f"\nDone! Processed {examples_processed} examples, created {pairs_created} DPO pairs")
    print(f"Output saved to: {output_jsonl}")
def convert_sft_to_grpo(
    sft_jsonl: str,
    output_jsonl: str,
    model_path: str = None,
    num_samples: int = 4,
    max_examples: int = None
):
    """
    Convert an SFT dataset to GRPO format: several completions per prompt plus F1 scores.

    Each SFT record must carry an ``"input"`` prompt and an ``"output"``
    reference. Candidates are either sampled from the model at
    ``model_path`` or, when no model is given, read from the record's
    ``"outputs"`` list. Each candidate is scored with
    ``compute_file_level_f1(candidate, reference)["f1"]``. One
    ``{"prompt", "completions", "scores"}`` object is written per prompt.

    Args:
        sft_jsonl: Path to the SFT JSONL file (blank lines are ignored).
        output_jsonl: Path to the output GRPO JSONL file (overwritten).
        model_path: Path to a model for generation. If None, pre-generated
            outputs are expected in each record's "outputs" field.
        num_samples: Number of outputs to generate per prompt.
        max_examples: Maximum number of examples to process (None = all).
    """
    if model_path:
        model, tokenizer = load_model_for_generation(model_path)
    else:
        print("Warning: No model path provided. Expecting pre-generated outputs in data.")
        model, tokenizer = None, None

    examples_created = 0
    examples_processed = 0
    with open(sft_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            # `is not None` so that max_examples=0 means "none", not "all".
            if max_examples is not None and examples_processed >= max_examples:
                break
            if not line.strip():
                continue  # tolerate blank lines instead of crashing json.loads
            data = json.loads(line)
            prompt = data.get("input", "")
            ground_truth = data.get("output", "")
            if not prompt or not ground_truth:
                continue

            # Obtain candidate outputs: sample them, or read pre-generated ones.
            if model is not None and tokenizer is not None:
                try:
                    generations = generate_multiple_outputs(
                        model, tokenizer, prompt, num_samples
                    )
                except Exception as e:
                    # Best effort: one failing prompt must not abort the run.
                    print(f"Error generating outputs: {e}")
                    continue
            else:
                generations = data.get("outputs", [])

            # A group with a single completion gives no relative signal.
            if len(generations) < 2:
                print("Skipping example: need at least 2 outputs")
                continue

            # Score every candidate against the reference with file-level F1.
            scores = [
                compute_file_level_f1(generation, ground_truth)["f1"]
                for generation in generations
            ]

            grpo_example = {
                "prompt": prompt,
                "completions": generations,
                "scores": scores
            }
            f_out.write(json.dumps(grpo_example) + '\n')
            examples_created += 1

            examples_processed += 1
            if examples_processed % 10 == 0:
                print(f"Processed {examples_processed} examples")
                print(f" Last example F1 scores: {[f'{s:.3f}' for s in scores]}")

    print(f"\nDone! Created {examples_created} GRPO examples from {examples_processed} SFT examples")
    print(f"Output saved to: {output_jsonl}")
def analyze_dataset(jsonl_path: str, dataset_type: str = "auto"):
    """
    Print summary statistics for an SFT, DPO or GRPO JSONL dataset.

    Blank lines are ignored. With ``dataset_type="auto"``, the type is
    inferred from the first record: ``chosen`` + ``rejected`` means DPO,
    ``completions`` + ``scores`` means GRPO, and anything else means SFT.

    Args:
        jsonl_path: Path to the JSONL file.
        dataset_type: "sft", "dpo", "grpo", or "auto" (auto-detect).
    """
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    if not records:
        print("Empty dataset")
        return

    first = records[0]
    if dataset_type == "auto":
        if "chosen" in first and "rejected" in first:
            dataset_type = "dpo"
        elif "completions" in first and "scores" in first:
            dataset_type = "grpo"
        else:
            dataset_type = "sft"

    print(f"\nDataset Analysis: {jsonl_path}")
    print(f"Type: {dataset_type.upper()}")
    print(f"Total examples: {len(records)}")

    if dataset_type == "dpo":
        # NOTE(review): missing f1 fields default to 1.0 / 0.0, which reports
        # a difference of 1.0 for pairs without scores - confirm this is intended.
        f1_diffs = [
            rec.get("chosen_f1", 1.0) - rec.get("rejected_f1", 0.0) for rec in records
        ]
        print(f"Average F1 difference: {sum(f1_diffs) / len(f1_diffs):.3f}")
        print(f"Min F1 difference: {min(f1_diffs):.3f}")
        print(f"Max F1 difference: {max(f1_diffs):.3f}")
    elif dataset_type == "grpo":
        completion_counts = [len(rec.get("scores", [])) for rec in records]
        all_scores = [s for rec in records for s in rec.get("scores", [])]
        print(f"Average completions per prompt: {sum(completion_counts) / len(completion_counts):.1f}")
        print(f"Min completions: {min(completion_counts)}")
        print(f"Max completions: {max(completion_counts)}")
        # Guard: records may all have empty score lists.
        if all_scores:
            print(f"Average F1 score: {sum(all_scores) / len(all_scores):.3f}")
            print(f"F1 score range: [{min(all_scores):.3f}, {max(all_scores):.3f}]")
        else:
            print("No F1 scores found")
def main():
    """
    Command-line entry point: convert SFT JSONL to DPO/GRPO, or analyze a dataset.

    With ``--analyze-only``, the ``--input`` file is analyzed as a dataset
    of the given ``--format`` and no conversion happens. In that mode
    ``--output`` is not needed.
    """
    parser = argparse.ArgumentParser(description="Convert SFT data to DPO/GRPO formats")
    parser.add_argument("--input", required=True,
                        help="Input SFT JSONL file (or the dataset to inspect with --analyze-only)")
    parser.add_argument("--output", default=None,
                        help="Output JSONL file (required unless --analyze-only)")
    parser.add_argument("--format", choices=["dpo", "grpo"], required=True,
                        help="Output format")
    parser.add_argument("--model", default=None,
                        help="Path to model for generation (optional)")
    parser.add_argument("--num-samples", type=int, default=4,
                        help="Number of outputs to generate per prompt")
    parser.add_argument("--max-examples", type=int, default=None,
                        help="Maximum number of examples to process")
    parser.add_argument("--min-f1-diff", type=float, default=0.1,
                        help="Minimum F1 difference for DPO pairs")
    parser.add_argument("--analyze", action="store_true",
                        help="Analyze the output dataset after creation")
    parser.add_argument("--analyze-only", action="store_true",
                        help="Only analyze the --input dataset; skip conversion")
    args = parser.parse_args()

    if args.analyze_only:
        analyze_dataset(args.input, args.format)
        return

    # --output is optional only in --analyze-only mode; enforce it here with
    # the same message and exit status argparse uses for required options.
    if args.output is None:
        parser.error("the following arguments are required: --output")

    print(f"Converting {args.input} to {args.format.upper()} format...")
    print(f"Output: {args.output}")
    if args.format == "dpo":
        convert_sft_to_dpo(
            args.input,
            args.output,
            args.model,
            args.num_samples,
            args.min_f1_diff,
            args.max_examples
        )
    elif args.format == "grpo":
        convert_sft_to_grpo(
            args.input,
            args.output,
            args.model,
            args.num_samples,
            args.max_examples
        )
    if args.analyze:
        analyze_dataset(args.output, args.format)


if __name__ == "__main__":
    import sys
    # With no CLI arguments, print usage examples instead of an argparse error.
    if len(sys.argv) == 1:
        print("Data Preparation Utilities")
        print("=" * 50)
        print("\nUsage:")
        print(" python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo")
        print(" python prepare_data.py --input instruct_data.jsonl --output grpo_data.jsonl --format grpo")
        print("\nWith model generation:")
        print(" python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo \\")
        print(" --model ./runs/instruct_run_14b_v1/merged_14b_instruct_lora --num-samples 4")
        print("\nAnalyze dataset:")
        print(" python prepare_data.py --input dpo_data.jsonl --format dpo --analyze-only")
        sys.exit(0)
    main()