# prepare_grpo_data_vllm.py — task2file-llm / trainer-kit / GRPO-14B
# Uploaded by SirajRLX via huggingface_hub (revision d6bd954, verified).
"""
Fast GRPO Data Generation using vLLM
Generates multiple completions per prompt in parallel using vLLM for 10-20x speedup.
Usage:
pip install vllm
python prepare_grpo_data_vllm.py \
--sft_dataset /workspace/rl_dataset.jsonl \
--output_dataset grpo_dataset.jsonl \
--model_path /workspace/Models/Qwen2.5-Coder-14B-CPT-SFT_v2 \
--num_completions 6 \
--batch_size 50
"""
import json
import argparse
from pathlib import Path
from typing import List, Dict
from tqdm import tqdm
from collections import Counter
import re
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
print("vLLM not available. Install with: pip install vllm")
def compute_f1_score(prediction: str, reference: str) -> float:
    """Compute a token-level F1 score between a prediction and a reference.

    Tokens are lowercased words and individual punctuation characters;
    overlap is counted with multiplicity (multiset intersection of the
    two token bags).

    Args:
        prediction: Generated completion text.
        reference: Ground-truth output text.

    Returns:
        F1 in [0.0, 1.0]; 0.0 when either side has no tokens or there is
        no token overlap.
    """
    def tokenize(text: str) -> Counter:
        # Split on word characters and single punctuation marks, case-folded.
        return Counter(re.findall(r'\w+|[^\w\s]', text.lower()))

    pred_tokens = tokenize(prediction)
    ref_tokens = tokenize(reference)
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection: shared tokens counted with multiplicity.
    overlap = sum((pred_tokens & ref_tokens).values())
    if overlap == 0:
        # Equivalent to the precision+recall == 0 check: with non-empty
        # token bags, both metrics are zero iff there is no overlap.
        return 0.0
    # Denominators are non-zero here — both Counters are non-empty.
    precision = overlap / sum(pred_tokens.values())
    recall = overlap / sum(ref_tokens.values())
    return 2 * (precision * recall) / (precision + recall)
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
    """Load a JSONL SFT dataset (one JSON object per line).

    Args:
        path: Path to the JSONL file.
        max_samples: Optional cap on the number of samples to load; None
            loads everything. Fix: the original used a truthiness check, so
            ``max_samples=0`` silently loaded the whole file; 0 now loads
            zero samples as expected.

    Returns:
        List of parsed sample dicts, in file order.
    """
    samples: List[Dict] = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if max_samples is not None and len(samples) >= max_samples:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (trailing newline, hand-edited files)
                # instead of crashing in json.loads.
                continue
            samples.append(json.loads(line))
    return samples
def format_prompt(instruction: str, input_text: str) -> str:
    """Build the model prompt in the project's custom chat template.

    The instruction block and the user input are each terminated by the
    ``<|im_end|>`` sentinel used at training time.
    """
    segments = (
        "##INSTRUCTION\n",
        instruction,
        "<|im_end|>\n",
        input_text,
        "<|im_end|>\n",
    )
    return "".join(segments)
def generate_completions_vllm(
    llm: LLM,
    prompts: List[str],
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float = 0.95,
) -> List[List[str]]:
    """Sample multiple completions per prompt with a single vLLM call.

    Args:
        llm: vLLM ``LLM`` instance.
        prompts: Prompts to generate from; vLLM batches them internally.
        num_completions: Candidates to sample per prompt.
        temperature: Sampling temperature.
        max_tokens: Generation length cap per completion.
        top_p: Nucleus sampling cutoff.

    Returns:
        One list of stripped completion strings per input prompt, in the
        same order as ``prompts``.
    """
    params = SamplingParams(
        n=num_completions,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        # Stop at the custom EOS marker or the chat-template terminator.
        stop=["<EOS>", "<|im_end|>"],
    )
    # One generate() call covers the whole batch; vLLM schedules it.
    request_outputs = llm.generate(prompts, params)
    return [
        [candidate.text.strip() for candidate in request.outputs]
        for request in request_outputs
    ]
def main():
    """CLI entry point: build a GRPO dataset by sampling completions with vLLM.

    Reads a JSONL SFT dataset, generates ``--num_completions`` completions per
    prompt, scores each completion against the reference output with
    token-level F1, and streams the resulting GRPO samples to
    ``--output_dataset`` as JSONL. Prints score statistics at the end.

    Raises:
        ImportError: If vLLM is not installed.
    """
    parser = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM")
    parser.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
    parser.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
    parser.add_argument("--batch_size", type=int, default=50, help="Number of prompts to process at once")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process (for testing)")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.85, help="GPU memory utilization (0-1)")
    args = parser.parse_args()

    if not VLLM_AVAILABLE:
        raise ImportError("vLLM is required. Install with: pip install vllm")

    print(f"\n{'='*60}")
    print("GRPO Data Generation with vLLM")
    print(f"{'='*60}")
    print(f"Model: {args.model_path}")
    print(f"Input dataset: {args.sft_dataset}")
    print(f"Output dataset: {args.output_dataset}")
    print(f"Completions per prompt: {args.num_completions}")
    print(f"Batch size: {args.batch_size}")
    print(f"Temperature: {args.temperature}")
    print(f"Max samples: {args.max_samples or 'All'}")
    print(f"{'='*60}\n")

    print("Loading SFT dataset...")
    sft_samples = load_sft_dataset(args.sft_dataset, args.max_samples)
    print(f"Loaded {len(sft_samples)} samples\n")

    print("Initializing vLLM...")
    llm = LLM(
        model=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        dtype="bfloat16",
        max_model_len=4096,
    )
    print("vLLM initialized!\n")

    # Stream results to disk; keep only the per-completion scores for the
    # summary statistics instead of accumulating every GRPO sample in memory
    # (the original retained all samples for no benefit — unbounded growth
    # on large datasets).
    all_scores: List[float] = []
    num_written = 0
    num_batches = (len(sft_samples) + args.batch_size - 1) // args.batch_size

    with open(args.output_dataset, 'w') as f_out:
        for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
            batch_start = batch_idx * args.batch_size
            batch_end = min(batch_start + args.batch_size, len(sft_samples))
            batch_samples = sft_samples[batch_start:batch_end]

            # Build prompts and reference outputs for this batch.
            prompts = []
            references = []
            for sample in batch_samples:
                instruction = sample.get('instruction', 'You are a helpful assistant.')
                input_text = sample.get('input', '')
                output_text = sample.get('output', '')
                prompts.append(format_prompt(instruction, input_text))
                references.append(output_text)

            # One vLLM call generates all completions for the batch.
            batch_completions = generate_completions_vllm(
                llm=llm,
                prompts=prompts,
                num_completions=args.num_completions,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )

            # Score each completion against its reference and stream the
            # GRPO sample immediately so progress survives interruption.
            for prompt, completions, reference in zip(prompts, batch_completions, references):
                scores = [compute_f1_score(comp, reference) for comp in completions]
                grpo_sample = {
                    "prompt": prompt,
                    "completions": completions,
                    "scores": scores,
                }
                f_out.write(json.dumps(grpo_sample) + '\n')
                f_out.flush()
                all_scores.extend(scores)
                num_written += 1

    print(f"\n{'='*60}")
    print("Generation Complete!")
    print(f"{'='*60}")
    print(f"Generated {num_written} GRPO samples")
    print(f"Output saved to: {args.output_dataset}")

    avg_score = sum(all_scores) / len(all_scores) if all_scores else 0
    max_score = max(all_scores) if all_scores else 0
    min_score = min(all_scores) if all_scores else 0
    print(f"\nScore Statistics:")
    print(f"  Average F1: {avg_score:.3f}")
    print(f"  Max F1: {max_score:.3f}")
    print(f"  Min F1: {min_score:.3f}")
    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()