|
|
""" |
|
|
Data preparation utilities for converting SFT data to DPO/GRPO formats. |
|
|
This script helps generate multiple outputs and create preference/ranking datasets. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import List, Dict |
|
|
from f1_score_utils import ( |
|
|
compute_file_level_f1, |
|
|
rank_outputs_by_f1, |
|
|
create_dpo_pairs_from_generations |
|
|
) |
|
|
|
|
|
|
|
|
def load_model_for_generation(model_path: str):
    """
    Load a causal LM and its tokenizer for generation.

    Placeholder implementation - adjust dtype/device placement for your setup.

    Args:
        model_path: Local path or hub identifier of the model to load.

    Returns:
        A ``(model, tokenizer)`` tuple ready for ``model.generate``.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_path}...")

    tok = AutoTokenizer.from_pretrained(model_path)
    # bf16 + device_map="auto" spreads the model across available devices.
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return lm, tok
|
|
|
|
|
|
|
|
def generate_multiple_outputs(
    model,
    tokenizer,
    prompt: str,
    num_samples: int = 4,
    temperatures: List[float] = None,
    max_new_tokens: int = 512
) -> List[str]:
    """
    Generate multiple outputs for a single prompt using different temperatures.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Matching tokenizer (callable, with ``.decode``).
        prompt: Prompt text to condition generation on.
        num_samples: Number of completions to produce when ``temperatures``
            is not given.
        temperatures: Explicit sampling temperatures; when provided, one
            completion is generated per entry and ``num_samples`` is ignored.
        max_new_tokens: Per-sample cap on generated tokens.

    Returns:
        One decoded completion (prompt tokens stripped) per temperature.
    """
    import torch

    if temperatures is None:
        # Cycle the base schedule so num_samples > 4 still yields
        # num_samples outputs (the old slice silently capped the count at 4).
        base = [0.6, 0.8, 1.0, 1.2]
        temperatures = [base[i % len(base)] for i in range(num_samples)]

    outputs = []
    for temp in temperatures:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # inference_mode: no autograd bookkeeping during generation,
        # which saves memory and time.
        with torch.inference_mode():
            generated = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temp,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            )

        # Slice off the prompt so only the completion is decoded.
        output_text = tokenizer.decode(
            generated[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        outputs.append(output_text)

    return outputs
|
|
|
|
|
|
|
|
def convert_sft_to_dpo(
    sft_jsonl: str,
    output_jsonl: str,
    model_path: str = None,
    num_samples: int = 4,
    min_f1_difference: float = 0.1,
    max_examples: int = None
):
    """
    Convert SFT dataset to DPO format by generating multiple outputs and creating pairs.

    Args:
        sft_jsonl: Path to SFT JSONL file (records with "input"/"output" keys)
        output_jsonl: Path to output DPO JSONL file
        model_path: Path to model for generation (if None, expects pre-generated
            outputs under each record's "outputs" key)
        num_samples: Number of outputs to generate per prompt
        min_f1_difference: Minimum F1 difference to create a pair
        max_examples: Maximum number of examples to process (None = all)
    """
    if model_path:
        model, tokenizer = load_model_for_generation(model_path)
    else:
        print("Warning: No model path provided. Expecting pre-generated outputs in data.")
        model, tokenizer = None, None

    pairs_created = 0
    examples_processed = 0

    # Explicit UTF-8 so behavior doesn't depend on the platform's default encoding.
    with open(sft_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            # `is not None`: max_examples=0 now means "process nothing" instead
            # of the old truthiness check treating 0 as "process everything".
            if max_examples is not None and examples_processed >= max_examples:
                break

            data = json.loads(line)
            prompt = data.get("input", "")
            ground_truth = data.get("output", "")

            # Skip records missing either side of the pair.
            if not prompt or not ground_truth:
                continue

            if model and tokenizer:
                try:
                    generations = generate_multiple_outputs(
                        model, tokenizer, prompt, num_samples
                    )
                except Exception as e:
                    # Best-effort: log and move on rather than abort the run.
                    print(f"Error generating outputs: {e}")
                    continue
            else:
                generations = data.get("outputs", [])
                if len(generations) < 2:
                    print("Skipping example: need at least 2 outputs")
                    continue

            # Rank generations against the reference and keep pairs whose
            # F1 gap clears the threshold.
            pairs = create_dpo_pairs_from_generations(
                prompt, generations, ground_truth, min_f1_difference
            )

            for pair in pairs:
                f_out.write(json.dumps(pair) + '\n')
                pairs_created += 1

            examples_processed += 1
            if examples_processed % 10 == 0:
                print(f"Processed {examples_processed} examples, created {pairs_created} pairs")

    print(f"\nDone! Processed {examples_processed} examples, created {pairs_created} DPO pairs")
    print(f"Output saved to: {output_jsonl}")
|
|
|
|
|
|
|
|
def convert_sft_to_grpo(
    sft_jsonl: str,
    output_jsonl: str,
    model_path: str = None,
    num_samples: int = 4,
    max_examples: int = None
):
    """
    Convert SFT dataset to GRPO format by generating multiple outputs and computing scores.

    Args:
        sft_jsonl: Path to SFT JSONL file (records with "input"/"output" keys)
        output_jsonl: Path to output GRPO JSONL file
        model_path: Path to model for generation (if None, expects pre-generated
            outputs under each record's "outputs" key)
        num_samples: Number of outputs to generate per prompt
        max_examples: Maximum number of examples to process (None = all)
    """
    if model_path:
        model, tokenizer = load_model_for_generation(model_path)
    else:
        print("Warning: No model path provided. Expecting pre-generated outputs in data.")
        model, tokenizer = None, None

    examples_created = 0
    examples_processed = 0

    # Explicit UTF-8 so behavior doesn't depend on the platform's default encoding.
    with open(sft_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            # `is not None`: max_examples=0 now means "process nothing" instead
            # of the old truthiness check treating 0 as "process everything".
            if max_examples is not None and examples_processed >= max_examples:
                break

            data = json.loads(line)
            prompt = data.get("input", "")
            ground_truth = data.get("output", "")

            # Skip records missing either the prompt or the reference.
            if not prompt or not ground_truth:
                continue

            if model and tokenizer:
                try:
                    generations = generate_multiple_outputs(
                        model, tokenizer, prompt, num_samples
                    )
                except Exception as e:
                    # Best-effort: log and move on rather than abort the run.
                    print(f"Error generating outputs: {e}")
                    continue
            else:
                generations = data.get("outputs", [])
                if len(generations) < 2:
                    print("Skipping example: need at least 2 outputs")
                    continue

            # Score every candidate against the reference with file-level F1;
            # these become the per-completion rewards for GRPO.
            scores = []
            for generation in generations:
                metrics = compute_file_level_f1(generation, ground_truth)
                scores.append(metrics["f1"])

            grpo_example = {
                "prompt": prompt,
                "completions": generations,
                "scores": scores
            }

            f_out.write(json.dumps(grpo_example) + '\n')
            examples_created += 1
            examples_processed += 1

            if examples_processed % 10 == 0:
                print(f"Processed {examples_processed} examples")
                print(f" Last example F1 scores: {[f'{s:.3f}' for s in scores]}")

    print(f"\nDone! Created {examples_created} GRPO examples from {examples_processed} SFT examples")
    print(f"Output saved to: {output_jsonl}")
|
|
|
|
|
|
|
|
def analyze_dataset(jsonl_path: str, dataset_type: str = "auto"):
    """
    Print summary statistics for a JSONL dataset.

    Args:
        jsonl_path: Path to JSONL file
        dataset_type: "sft", "dpo", "grpo", or "auto" (detect from the keys
            of the first record)
    """
    with open(jsonl_path, 'r') as fh:
        lines = fh.readlines()

    if not lines:
        print("Empty dataset")
        return

    head = json.loads(lines[0])

    # Infer the format from which keys the first record carries.
    if dataset_type == "auto":
        if "chosen" in head and "rejected" in head:
            dataset_type = "dpo"
        elif "completions" in head and "scores" in head:
            dataset_type = "grpo"
        else:
            dataset_type = "sft"

    print(f"\nDataset Analysis: {jsonl_path}")
    print(f"Type: {dataset_type.upper()}")
    print(f"Total examples: {len(lines)}")

    if dataset_type == "dpo":
        # Gap between chosen and rejected F1 for every pair.
        records = [json.loads(raw) for raw in lines]
        f1_diffs = [
            rec.get("chosen_f1", 1.0) - rec.get("rejected_f1", 0.0)
            for rec in records
        ]

        print(f"Average F1 difference: {sum(f1_diffs) / len(f1_diffs):.3f}")
        print(f"Min F1 difference: {min(f1_diffs):.3f}")
        print(f"Max F1 difference: {max(f1_diffs):.3f}")

    elif dataset_type == "grpo":
        # Flatten all per-completion scores and track group sizes.
        all_scores = []
        completion_counts = []
        for raw in lines:
            scores = json.loads(raw).get("scores", [])
            all_scores.extend(scores)
            completion_counts.append(len(scores))

        print(f"Average completions per prompt: {sum(completion_counts) / len(completion_counts):.1f}")
        print(f"Min completions: {min(completion_counts)}")
        print(f"Max completions: {max(completion_counts)}")
        print(f"Average F1 score: {sum(all_scores) / len(all_scores):.3f}")
        print(f"F1 score range: [{min(all_scores):.3f}, {max(all_scores):.3f}]")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: convert an SFT JSONL file to DPO or GRPO format."""
    parser = argparse.ArgumentParser(description="Convert SFT data to DPO/GRPO formats")

    # Required I/O and format selection.
    parser.add_argument("--input", required=True, help="Input SFT JSONL file")
    parser.add_argument("--output", required=True, help="Output JSONL file")
    parser.add_argument(
        "--format", choices=["dpo", "grpo"], required=True, help="Output format"
    )

    # Optional generation / filtering knobs.
    parser.add_argument(
        "--model", default=None, help="Path to model for generation (optional)"
    )
    parser.add_argument(
        "--num-samples", type=int, default=4,
        help="Number of outputs to generate per prompt",
    )
    parser.add_argument(
        "--max-examples", type=int, default=None,
        help="Maximum number of examples to process",
    )
    parser.add_argument(
        "--min-f1-diff", type=float, default=0.1,
        help="Minimum F1 difference for DPO pairs",
    )
    parser.add_argument(
        "--analyze", action="store_true",
        help="Analyze the output dataset after creation",
    )

    args = parser.parse_args()

    print(f"Converting {args.input} to {args.format.upper()} format...")
    print(f"Output: {args.output}")

    if args.format == "dpo":
        convert_sft_to_dpo(
            args.input,
            args.output,
            args.model,
            args.num_samples,
            args.min_f1_diff,
            args.max_examples,
        )
    elif args.format == "grpo":
        convert_sft_to_grpo(
            args.input,
            args.output,
            args.model,
            args.num_samples,
            args.max_examples,
        )

    # Optionally summarize what was just written.
    if args.analyze:
        analyze_dataset(args.output, args.format)
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    import sys

    # With no CLI arguments, show usage examples instead of letting argparse
    # exit with a terse "required argument" error.
    if len(sys.argv) == 1:
        print("Data Preparation Utilities")
        print("=" * 50)
        print("\nUsage:")
        print("  python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo")
        print("  python prepare_data.py --input instruct_data.jsonl --output grpo_data.jsonl --format grpo")
        print("\nWith model generation:")
        print("  python prepare_data.py --input instruct_data.jsonl --output dpo_data.jsonl --format dpo \\")
        print("    --model ./runs/instruct_run_14b_v1/merged_14b_instruct_lora --num-samples 4")
        print("\nAnalyze dataset:")
        print("  python prepare_data.py --input dpo_data.jsonl --output /dev/null --format dpo --analyze")
        sys.exit(0)

    # Normal CLI invocation: delegate to argparse-based entry point.
    main()
|
|
|