| """ |
| Fast GRPO Data Generation using vLLM API Server |
| |
| This script makes parallel async API calls to a vLLM server with semaphore-based |
| concurrency control for optimal throughput. |
| |
| Usage: |
| # First, start vLLM server in another terminal: |
| bash start_vllm_server.sh |
| |
| # Then run this script: |
| python prepare_grpo_data_api.py \ |
| --sft_dataset /workspace/rl_dataset.jsonl \ |
| --output_dataset grpo_dataset.jsonl \ |
| --num_completions 6 \ |
| --max_concurrent 50 \ |
| --api_url http://localhost:8000/v1 |
| """ |
|
|
| import json |
| import argparse |
| import asyncio |
| from pathlib import Path |
| from typing import List, Dict |
| from tqdm.asyncio import tqdm_asyncio |
| from collections import Counter |
| import re |
| import aiohttp |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class GenerationRequest:
    """A single prompt/reference pair queued for completion generation.

    NOTE(review): not referenced elsewhere in this file — confirm whether it
    is used by an importer or is dead code.
    """
    # Fully formatted prompt text sent to the model.
    prompt: str
    # Ground-truth output the completions are scored against.
    reference: str
    # Arbitrary per-sample bookkeeping (source ids, etc.).
    metadata: Dict
|
|
|
|
def compute_f1_score(prediction: str, reference: str) -> float:
    """Return the token-level F1 score between *prediction* and *reference*.

    Both strings are lowercased and split into word / single-punctuation
    tokens; overlap is counted with multiplicity via Counter intersection.
    Returns 0.0 when either side tokenizes to nothing or nothing overlaps.
    """
    token_pattern = re.compile(r'\w+|[^\w\s]')

    pred_counts = Counter(token_pattern.findall(prediction.lower()))
    ref_counts = Counter(token_pattern.findall(reference.lower()))

    # An empty side makes precision/recall undefined — score it as 0.
    if not pred_counts or not ref_counts:
        return 0.0

    common = sum((pred_counts & ref_counts).values())
    if common == 0:
        # No shared tokens: precision + recall is 0, so F1 is 0.
        return 0.0

    precision = common / sum(pred_counts.values())
    recall = common / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)
|
|
|
|
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
    """Load a JSON-lines SFT dataset.

    Args:
        path: Path to a JSONL file, one JSON object per line.
        max_samples: If not None, stop after this many samples
            (``0`` now means "load nothing"; the previous truthiness check
            silently treated 0 as "load all").

    Returns:
        List of parsed sample dicts.
    """
    samples = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if max_samples is not None and len(samples) >= max_samples:
                break
            line = line.strip()
            if not line:
                # Tolerate blank / trailing lines that would crash json.loads.
                continue
            samples.append(json.loads(line))
    return samples
|
|
|
|
def format_prompt(instruction: str, input_text: str) -> str:
    """Build the full model prompt for one sample.

    Args:
        instruction: Accepted for signature compatibility but currently
            UNUSED — the hard-coded ``system_prompt`` below is sent instead.
            NOTE(review): confirm this is intentional.
        input_text: Task description inserted after the ##TASK marker.

    Returns:
        Prompt string in the custom ##INSTRUCTION/##TASK template, with
        ``<|im_end|>`` separators and a trailing newline.
    """
    # NOTE(review): the "## Rules" list below numbers two items "3.", and the
    # example ##SELECT lines omit the modify::/add:: prefix shown earlier.
    # This text is sent to the model at runtime, so it is left byte-identical
    # here — fix it deliberately if the prompt wording is to change.
    system_prompt = """You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.

## Output Format

##OUTPUT
Explain the data flow and why each component must change:
- Flow: [Input → Processing → Output with arrows]
- For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
- Explain coupling between components

##SELECT
modify::crates/path/to/file.rs::impl::ComponentName
add::crates/another/file.rs::function::AnotherComponent
<EOS>

## Rules

1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
2. Use `::` for nested items: `status::StructName::Type::Name`
3. Always explain "must change because" and "without this"
3. Types of components: function, struct, enum, impl, trait
4. If there is extra information (e.g., enum variants), include that too.
5. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>

## Example

##TASK
Add webhook subscription support

##OUTPUT
The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.

##SELECT
crates/common_enums/src/enums.rs::EventClass
crates/common_enums/src/transformers.rs::SubscriptionStatus
<EOS>"""

    return f"##INSTRUCTION\n{system_prompt}<|im_end|>\n##TASK\n{input_text}<|im_end|>\n"
|
|
|
|
async def generate_completions_api(
    session: aiohttp.ClientSession,
    api_url: str,
    prompt: str,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float = 0.95,
    semaphore: asyncio.Semaphore = None,
) -> List[str]:
    """
    Generate completions via the vLLM API, optionally under a semaphore.

    Args:
        session: aiohttp session
        api_url: vLLM API base URL
        prompt: Input prompt
        num_completions: Number of completions to generate
        temperature: Sampling temperature
        max_tokens: Max tokens per completion
        top_p: Nucleus sampling
        semaphore: Optional semaphore bounding concurrent requests

    Returns:
        List of generated completions (empty strings on API failure).
    """
    # Build the coroutine once; the previous version duplicated the call in
    # both branches. It is awaited exactly once either way.
    call = _generate_completions_api(
        session, api_url, prompt, num_completions,
        temperature, max_tokens, top_p
    )
    if semaphore is None:
        return await call
    async with semaphore:
        return await call
|
|
|
|
async def _generate_completions_api(
    session: aiohttp.ClientSession,
    api_url: str,
    prompt: str,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float,
    model: str = "qwen2.5-coder-14b",
) -> List[str]:
    """POST one /completions request and return the generated texts.

    Args:
        model: Served model name. Previously hard-coded; now a parameter
            with the same default, so existing callers are unaffected.

    Returns:
        ``num_completions`` stripped completion strings. On any failure
        (non-200 status, timeout, network error) the error is printed and a
        list of empty strings is returned so the pipeline keeps going.
    """
    url = f"{api_url}/completions"

    payload = {
        "model": model,
        "prompt": prompt,
        "n": num_completions,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        # Generation halts at the custom terminator or chat end token.
        "stop": ["<EOS>", "<|im_end|>"],
    }

    try:
        # Generous per-request ceiling; large n at high max_tokens is slow.
        async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as response:
            if response.status != 200:
                error_text = await response.text()
                print(f"API Error {response.status}: {error_text}")
                return [""] * num_completions

            result = await response.json()
            completions = [choice["text"].strip() for choice in result["choices"]]
            return completions
    except asyncio.TimeoutError:
        print(f"Timeout for prompt (length={len(prompt)})")
        return [""] * num_completions
    except Exception as e:
        # Deliberate best-effort: never let one bad request kill the run.
        print(f"Error generating completions: {e}")
        return [""] * num_completions
|
|
|
|
async def process_sample(
    session: aiohttp.ClientSession,
    api_url: str,
    sample: Dict,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    semaphore: asyncio.Semaphore,
) -> Dict:
    """Generate completions for one SFT sample and score them.

    The ground-truth output is included as the first completion with a
    perfect score, so only ``num_completions - 1`` are requested from the
    model.

    Returns:
        Dict with ``prompt``, ``completions`` (ground truth first), and
        parallel ``scores`` (1.0 for ground truth, token-F1 otherwise).
    """
    instruction = sample.get('instruction', 'You are a helpful assistant.')
    input_text = sample.get('input', '')
    output_text = sample.get('output', '')

    prompt = format_prompt(instruction, input_text)

    n_model = num_completions - 1
    if n_model > 0:
        model_completions = await generate_completions_api(
            session=session,
            api_url=api_url,
            prompt=prompt,
            num_completions=n_model,
            temperature=temperature,
            max_tokens=max_tokens,
            semaphore=semaphore,
        )
    else:
        # num_completions <= 1: nothing to request. Previously this sent
        # n <= 0 to the API, which the server rejects.
        model_completions = []

    completions = [output_text] + model_completions
    scores = [1.0] + [compute_f1_score(comp, output_text) for comp in model_completions]

    return {
        "prompt": prompt,
        "completions": completions,
        "scores": scores,
    }
|
|
|
|
async def process_batch(
    session: aiohttp.ClientSession,
    api_url: str,
    samples: List[Dict],
    num_completions: int,
    temperature: float,
    max_tokens: int,
    semaphore: asyncio.Semaphore,
    output_file,
) -> List[Dict]:
    """Run process_sample over *samples* concurrently and persist results.

    Each result is appended to *output_file* as one JSON line and flushed
    immediately, so completed work survives a mid-run crash.
    """
    pending = [
        process_sample(
            session, api_url, one_sample, num_completions,
            temperature, max_tokens, semaphore
        )
        for one_sample in samples
    ]

    # tqdm_asyncio.gather preserves input order while showing progress.
    batch_results = await tqdm_asyncio.gather(*pending, desc="Processing batch")

    for record in batch_results:
        output_file.write(json.dumps(record) + '\n')
        output_file.flush()

    return batch_results
|
|
|
|
async def main_async(args):
    """Drive the full generation run.

    Loads the SFT dataset, fans completion requests out to the vLLM server
    batch by batch (in-flight requests bounded by a shared semaphore),
    streams results to the output JSONL file, then prints F1 statistics.

    Args:
        args: Parsed argparse namespace from main().
    """
    print(f"\n{'='*60}")
    print("GRPO Data Generation via vLLM API")
    print(f"{'='*60}")
    print(f"API URL: {args.api_url}")
    print(f"Input dataset: {args.sft_dataset}")
    print(f"Output dataset: {args.output_dataset}")
    print(f"Completions per prompt: {args.num_completions}")
    print(f"Max concurrent requests: {args.max_concurrent}")
    print(f"Temperature: {args.temperature}")
    print(f"Max samples: {args.max_samples or 'All'}")
    print(f"{'='*60}\n")


    print("Loading SFT dataset...")
    samples = load_sft_dataset(args.sft_dataset, args.max_samples)
    print(f"Loaded {len(samples)} samples\n")


    # One semaphore shared by every request caps total in-flight API calls;
    # batching below is only for progress reporting.
    semaphore = asyncio.Semaphore(args.max_concurrent)


    all_results = []

    async with aiohttp.ClientSession() as session:
        with open(args.output_dataset, 'w') as f_out:

            batch_size = args.batch_size
            # Ceiling division: last batch may be smaller than batch_size.
            num_batches = (len(samples) + batch_size - 1) // batch_size

            for batch_idx in range(num_batches):
                batch_start = batch_idx * batch_size
                batch_end = min(batch_start + batch_size, len(samples))
                batch_samples = samples[batch_start:batch_end]

                print(f"\nBatch {batch_idx + 1}/{num_batches} ({len(batch_samples)} samples)")

                # process_batch writes each result to f_out as it finishes
                # the batch, so partial progress is persisted.
                results = await process_batch(
                    session=session,
                    api_url=args.api_url,
                    samples=batch_samples,
                    num_completions=args.num_completions,
                    temperature=args.temperature,
                    max_tokens=args.max_tokens,
                    semaphore=semaphore,
                    output_file=f_out,
                )

                all_results.extend(results)


    print(f"\n{'='*60}")
    print("Generation Complete!")
    print(f"{'='*60}")
    print(f"Generated {len(all_results)} GRPO samples")
    print(f"Output saved to: {args.output_dataset}")


    # NOTE: scores include the ground-truth 1.0 entry per sample (added in
    # process_sample), so the average F1 is inflated accordingly.
    all_scores = [score for sample in all_results for score in sample['scores']]
    if all_scores:
        avg_score = sum(all_scores) / len(all_scores)
        max_score = max(all_scores)
        min_score = min(all_scores)

        print(f"\nScore Statistics:")
        print(f"  Average F1: {avg_score:.3f}")
        print(f"  Max F1: {max_score:.3f}")
        print(f"  Min F1: {min_score:.3f}")
    print(f"\n{'='*60}\n")
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the async pipeline."""
    arg_parser = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM API")

    # Required I/O paths.
    arg_parser.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
    arg_parser.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")

    # Server and throughput knobs.
    arg_parser.add_argument("--api_url", type=str, default="http://localhost:8000/v1", help="vLLM API URL")
    arg_parser.add_argument("--max_concurrent", type=int, default=50, help="Max concurrent API requests")
    arg_parser.add_argument("--batch_size", type=int, default=100, help="Batch size for progress tracking")

    # Sampling configuration.
    arg_parser.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
    arg_parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    arg_parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
    arg_parser.add_argument("--max_samples", type=int, default=None, help="Max samples to process")

    asyncio.run(main_async(arg_parser.parse_args()))
|
|
|
|
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|