| """ |
| Fast GRPO Data Generation using vLLM |
| |
| Generates multiple completions per prompt in parallel using vLLM for 10-20x speedup. |
| |
| Usage: |
| pip install vllm |
| python prepare_grpo_data_vllm.py \ |
| --sft_dataset /workspace/rl_dataset.jsonl \ |
| --output_dataset grpo_dataset.jsonl \ |
| --model_path /workspace/Models/Qwen2.5-Coder-14B-CPT-SFT_v2 \ |
| --num_completions 6 \ |
| --batch_size 50 |
| """ |
|
|
| import json |
| import argparse |
| from pathlib import Path |
| from typing import List, Dict |
| from tqdm import tqdm |
| from collections import Counter |
| import re |
|
|
# Optional dependency: vLLM provides the fast batched generation backend.
# Import lazily-guarded so the module can still be imported (e.g. for the
# scoring helpers) on machines without vLLM; main() re-checks the flag and
# raises a clear ImportError before doing any work.
try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    print("vLLM not available. Install with: pip install vllm")
|
|
|
|
def compute_f1_score(prediction: str, reference: str) -> float:
    """Return the token-level F1 overlap between *prediction* and *reference*.

    Tokenization is lowercase word characters plus individual punctuation
    marks; overlap is counted with multiplicity via Counter intersection.
    Returns 0.0 when either side has no tokens or there is no overlap.
    """

    def to_counts(text: str) -> Counter:
        # Lowercase, then split into word runs and single punctuation chars.
        return Counter(re.findall(r'\w+|[^\w\s]', text.lower()))

    pred_counts = to_counts(prediction)
    ref_counts = to_counts(reference)

    # No tokens on either side: nothing to score.
    if not pred_counts or not ref_counts:
        return 0.0

    common = sum((pred_counts & ref_counts).values())

    precision = common / sum(pred_counts.values())
    recall = common / sum(ref_counts.values())

    # Harmonic mean; guard the zero-overlap case to avoid division by zero.
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
|
|
|
|
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
    """Load up to ``max_samples`` records from a JSONL SFT dataset.

    Args:
        path: Path to a JSON-lines file (one JSON object per line).
        max_samples: Optional cap on the number of records read;
            ``None`` reads the entire file.

    Returns:
        A list of parsed JSON objects in file order.
    """
    samples = []
    # Explicit encoding so behavior does not depend on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # `is not None` so that max_samples=0 correctly yields an empty
            # list (the previous truthiness test treated 0 the same as None
            # and loaded the whole file).
            if max_samples is not None and i >= max_samples:
                break
            samples.append(json.loads(line))
    return samples
|
|
|
|
def format_prompt(instruction: str, input_text: str) -> str:
    """Assemble the model prompt using the project's custom chat template.

    Layout: an ``##INSTRUCTION`` header line, the instruction terminated by
    an ``<|im_end|>`` marker, then the input text terminated the same way.
    """
    pieces = (
        "##INSTRUCTION\n",
        instruction,
        "<|im_end|>\n",
        input_text,
        "<|im_end|>\n",
    )
    return "".join(pieces)
|
|
|
|
def generate_completions_vllm(
    llm: LLM,
    prompts: List[str],
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float = 0.95,
) -> List[List[str]]:
    """Sample several completions for each prompt in one vLLM batch.

    Args:
        llm: vLLM LLM instance.
        prompts: List of formatted prompt strings.
        num_completions: Number of samples to draw per prompt.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens generated per completion.
        top_p: Nucleus sampling parameter.

    Returns:
        One inner list per prompt, each holding ``num_completions``
        whitespace-stripped completion strings.
    """
    params = SamplingParams(
        n=num_completions,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        # Cut generation at either the model's EOS sentinel or the
        # template's turn-end marker.
        stop=["<EOS>", "<|im_end|>"],
    )

    results = llm.generate(prompts, params)

    # vLLM returns one RequestOutput per prompt; each carries n candidates.
    return [
        [candidate.text.strip() for candidate in result.outputs]
        for result in results
    ]
|
|
|
|
def main():
    """CLI entry point: build a GRPO dataset from an SFT dataset.

    Pipeline: parse arguments -> load the JSONL SFT dataset -> initialize
    vLLM -> for each batch of prompts, sample ``--num_completions``
    candidates, score each against the reference output with token-level
    F1, and stream the resulting records to ``--output_dataset`` as JSONL.
    Prints aggregate score statistics at the end.

    Raises:
        ImportError: If vLLM is not installed.
    """
    parser = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM")
    parser.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
    parser.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
    parser.add_argument("--batch_size", type=int, default=50, help="Number of prompts to process at once")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process (for testing)")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.85, help="GPU memory utilization (0-1)")

    args = parser.parse_args()

    # Fail fast before any file or GPU work if the backend is missing.
    if not VLLM_AVAILABLE:
        raise ImportError("vLLM is required. Install with: pip install vllm")

    # Run-configuration banner.
    print(f"\n{'='*60}")
    print("GRPO Data Generation with vLLM")
    print(f"{'='*60}")
    print(f"Model: {args.model_path}")
    print(f"Input dataset: {args.sft_dataset}")
    print(f"Output dataset: {args.output_dataset}")
    print(f"Completions per prompt: {args.num_completions}")
    print(f"Batch size: {args.batch_size}")
    print(f"Temperature: {args.temperature}")
    print(f"Max samples: {args.max_samples or 'All'}")
    print(f"{'='*60}\n")

    print("Loading SFT dataset...")
    sft_samples = load_sft_dataset(args.sft_dataset, args.max_samples)
    print(f"Loaded {len(sft_samples)} samples\n")

    # max_model_len caps prompt+completion length; 4096 is fixed here —
    # prompts longer than this will be rejected by vLLM.
    print("Initializing vLLM...")
    llm = LLM(
        model=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        trust_remote_code=True,
        dtype="bfloat16",
        max_model_len=4096,
    )
    print("vLLM initialized!\n")

    grpo_samples = []
    # Ceiling division: last batch may be smaller than batch_size.
    num_batches = (len(sft_samples) + args.batch_size - 1) // args.batch_size

    with open(args.output_dataset, 'w') as f_out:
        for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
            batch_start = batch_idx * args.batch_size
            batch_end = min(batch_start + args.batch_size, len(sft_samples))
            batch_samples = sft_samples[batch_start:batch_end]

            # Build prompts and keep the reference outputs aligned by index.
            prompts = []
            references = []
            for sample in batch_samples:
                instruction = sample.get('instruction', 'You are a helpful assistant.')
                input_text = sample.get('input', '')
                output_text = sample.get('output', '')

                prompt = format_prompt(instruction, input_text)
                prompts.append(prompt)
                references.append(output_text)

            # One vLLM call generates num_completions candidates per prompt.
            batch_completions = generate_completions_vllm(
                llm=llm,
                prompts=prompts,
                num_completions=args.num_completions,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )

            for sample, prompt, completions, reference in zip(batch_samples, prompts, batch_completions, references):
                # Reward signal: token-level F1 of each candidate vs. the
                # SFT reference output.
                scores = [compute_f1_score(comp, reference) for comp in completions]

                grpo_sample = {
                    "prompt": prompt,
                    "completions": completions,
                    "scores": scores,
                }

                # Write and flush per sample so a crash mid-run loses at
                # most the current record.
                f_out.write(json.dumps(grpo_sample) + '\n')
                f_out.flush()
                grpo_samples.append(grpo_sample)

    print(f"\n{'='*60}")
    print("Generation Complete!")
    print(f"{'='*60}")
    print(f"Generated {len(grpo_samples)} GRPO samples")
    print(f"Output saved to: {args.output_dataset}")

    # Aggregate F1 statistics over every completion of every sample.
    all_scores = [score for sample in grpo_samples for score in sample['scores']]
    avg_score = sum(all_scores) / len(all_scores) if all_scores else 0
    max_score = max(all_scores) if all_scores else 0
    min_score = min(all_scores) if all_scores else 0

    print(f"\nScore Statistics:")
    print(f"  Average F1: {avg_score:.3f}")
    print(f"  Max F1: {max_score:.3f}")
    print(f"  Min F1: {min_score:.3f}")
    print(f"\n{'='*60}\n")
|
|
|
|
# Standard script entry guard: run the pipeline only when executed directly,
# not when imported for its helpers.
if __name__ == "__main__":
    main()
|
|