# task2file-llm / trainer-kit / GRPO-14B / prepare_grpo_data_api.py
# Uploaded by SirajRLX via huggingface_hub (commit d6bd954, verified)
"""
Fast GRPO Data Generation using vLLM API Server
This script makes parallel async API calls to a vLLM server with semaphore-based
concurrency control for optimal throughput.
Usage:
# First, start vLLM server in another terminal:
bash start_vllm_server.sh
# Then run this script:
python prepare_grpo_data_api.py \
--sft_dataset /workspace/rl_dataset.jsonl \
--output_dataset grpo_dataset.jsonl \
--num_completions 6 \
--max_concurrent 50 \
--api_url http://localhost:8000/v1
"""
import json
import argparse
import asyncio
from pathlib import Path
from typing import List, Dict
from tqdm.asyncio import tqdm_asyncio
from collections import Counter
import re
import aiohttp
from dataclasses import dataclass
@dataclass
class GenerationRequest:
    """Single generation request pairing a prompt with its reference answer.

    NOTE(review): not referenced anywhere else in this file — possibly kept
    for external callers; confirm before removing.
    """
    prompt: str      # fully formatted prompt text to send to the model
    reference: str   # ground-truth completion used for scoring
    metadata: Dict   # free-form per-sample bookkeeping — schema not shown here
def compute_f1_score(prediction: str, reference: str) -> float:
    r"""Compute a token-level F1 score between *prediction* and *reference*.

    Tokenization lowercases the text and splits it into word runs (``\w+``)
    and individual punctuation characters. Token multiplicities matter:
    ``Counter & Counter`` keeps the minimum count per token, so repeated
    tokens contribute only as often as they appear in both strings.

    Args:
        prediction: Model-generated text.
        reference: Ground-truth text.

    Returns:
        F1 in [0.0, 1.0]; 0.0 when either side has no tokens or there is
        no token overlap.
    """
    def tokenize(text: str) -> Counter:
        # \w+ captures words/numbers; [^\w\s] captures each punctuation char.
        return Counter(re.findall(r'\w+|[^\w\s]', text.lower()))

    pred_tokens = tokenize(prediction)
    ref_tokens = tokenize(reference)
    # Either side empty -> precision/recall undefined; score 0.
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection: min count per shared token.
    overlap = sum((pred_tokens & ref_tokens).values())
    if overlap == 0:
        # No shared tokens -> precision == recall == 0; avoid div-by-zero.
        # (The original re-checked `if pred_tokens else 0` here, which was
        # dead code after the early return above.)
        return 0.0
    precision = overlap / sum(pred_tokens.values())
    recall = overlap / sum(ref_tokens.values())
    return 2 * (precision * recall) / (precision + recall)
def load_sft_dataset(path: str, max_samples: int = None) -> List[Dict]:
    """Load a JSON-lines SFT dataset from *path*.

    Args:
        path: Path to a JSONL file, one JSON object per line.
        max_samples: If given, stop after this many samples. Uses an
            explicit ``is not None`` check so that ``max_samples=0``
            loads zero samples (the original truthiness check silently
            treated 0 as "no limit").

    Returns:
        List of parsed sample dicts, in file order.
    """
    samples = []
    with open(path, 'r') as f:
        for i, line in enumerate(f):
            if max_samples is not None and i >= max_samples:
                break
            samples.append(json.loads(line))
    return samples
def format_prompt(instruction: str, input_text: str) -> str:
    """Build the chat-formatted prompt for the Hyperswitch code-analysis task.

    Args:
        instruction: Accepted for interface compatibility but currently
            unused — the hard-coded system prompt below takes its place.
            NOTE(review): confirm this is intentional.
        input_text: Task description inserted after the ##TASK marker.

    Returns:
        Prompt string: ##INSTRUCTION + system prompt, then ##TASK + the
        task text, each terminated with <|im_end|>.
    """
    # Fixed system prompt. Fix: the Rules list was misnumbered
    # (1, 2, 3, 3, 4, 5 — "3." appeared twice); renumbered 1-6.
    system_prompt = """You are a Hyperswitch Rust code analyzer. Identify functions/structs that need modification for a given task.
## Output Format
##OUTPUT
Explain the data flow and why each component must change:
- Flow: [Input → Processing → Output with arrows]
- For each component: "The [ComponentName] ([path]) must [action] because [reason]—without this, [consequence]"
- Explain coupling between components
##SELECT
modify::crates/path/to/file.rs::impl::ComponentName
add::crates/another/file.rs::function::AnotherComponent
<EOS>
## Rules
1. Use full paths: `remove::crates/folder/file.rs::Type::Name`
2. Use `::` for nested items: `status::StructName::Type::Name`
3. Always explain "must change because" and "without this"
4. Types of components: function, struct, enum, impl, trait
5. If there is extra information (e.g., enum variants), include that too.
6. Start with ##OUTPUT, end with ##SELECT, terminate with <EOS>
## Example
##TASK
Add webhook subscription support
##OUTPUT
The webhook system routes events via EventClass enum. Flow: webhook → EventClass → handler → processing. The EventClass enum (crates/common_enums/src/enums.rs::EventClass) must add Subscriptions variant because it defines event routing—without this, subscription events cannot be processed. The SubscriptionStatus impl (crates/common_enums/src/transformers.rs::SubscriptionStatus) must map to EventType because it converts status to events—without this, status changes don't trigger webhooks. These are coupled: EventClass routes to handlers that use SubscriptionStatus mappings.
##SELECT
crates/common_enums/src/enums.rs::EventClass
crates/common_enums/src/transformers.rs::SubscriptionStatus
<EOS>"""
    return f"##INSTRUCTION\n{system_prompt}<|im_end|>\n##TASK\n{input_text}<|im_end|>\n"
async def generate_completions_api(
    session: aiohttp.ClientSession,
    api_url: str,
    prompt: str,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float = 0.95,
    semaphore: asyncio.Semaphore = None,
) -> List[str]:
    """Generate completions via the vLLM API, optionally gated by a semaphore.

    Args:
        session: Shared aiohttp session.
        api_url: vLLM API base URL.
        prompt: Input prompt text.
        num_completions: Number of completions to request.
        temperature: Sampling temperature.
        max_tokens: Max tokens per completion.
        top_p: Nucleus sampling parameter.
        semaphore: Optional concurrency limiter; when provided, the request
            is made while holding the semaphore.

    Returns:
        List of generated completion strings.
    """
    # Guard clause: no limiter -> call straight through.
    if semaphore is None:
        return await _generate_completions_api(
            session, api_url, prompt, num_completions,
            temperature, max_tokens, top_p
        )
    # Hold the semaphore for the duration of the API call only.
    async with semaphore:
        return await _generate_completions_api(
            session, api_url, prompt, num_completions,
            temperature, max_tokens, top_p
        )
async def _generate_completions_api(
    session: aiohttp.ClientSession,
    api_url: str,
    prompt: str,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> List[str]:
    """POST one /completions request and return the stripped generated texts.

    Any failure (non-200 status, timeout, or unexpected error/response shape)
    degrades to a list of empty strings so the caller always receives exactly
    ``num_completions`` entries.
    """
    # Fallback result reused by every error path.
    blanks = [""] * num_completions
    payload = {
        "model": "qwen2.5-coder-14b",  # Match the model name from vLLM server
        "prompt": prompt,
        "n": num_completions,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "stop": ["<EOS>", "<|im_end|>"],
    }
    try:
        async with session.post(
            f"{api_url}/completions",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=300),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                print(f"API Error {response.status}: {error_text}")
                return blanks
            body = await response.json()
            return [choice["text"].strip() for choice in body["choices"]]
    except asyncio.TimeoutError:
        print(f"Timeout for prompt (length={len(prompt)})")
        return blanks
    except Exception as e:
        print(f"Error generating completions: {e}")
        return blanks
async def process_sample(
    session: aiohttp.ClientSession,
    api_url: str,
    sample: Dict,
    num_completions: int,
    temperature: float,
    max_tokens: int,
    semaphore: asyncio.Semaphore,
) -> Dict:
    """Turn one SFT sample into a GRPO record of scored completions.

    The ground-truth answer occupies slot 0 with a fixed score of 1.0; the
    remaining ``num_completions - 1`` entries come from the model and are
    scored by token-level F1 against the ground truth.
    """
    ground_truth = sample.get('output', '')
    prompt = format_prompt(
        sample.get('instruction', 'You are a helpful assistant.'),
        sample.get('input', ''),
    )
    # Request one fewer completion from the server; ground truth fills slot 0.
    generated = await generate_completions_api(
        session=session,
        api_url=api_url,
        prompt=prompt,
        num_completions=num_completions - 1,
        temperature=temperature,
        max_tokens=max_tokens,
        semaphore=semaphore,
    )
    return {
        "prompt": prompt,
        "completions": [ground_truth] + generated,
        "scores": [1.0] + [compute_f1_score(text, ground_truth) for text in generated],
    }
async def process_batch(
    session: aiohttp.ClientSession,
    api_url: str,
    samples: List[Dict],
    num_completions: int,
    temperature: float,
    max_tokens: int,
    semaphore: asyncio.Semaphore,
    output_file,
) -> List[Dict]:
    """Run process_sample over *samples* concurrently and persist each result.

    Every result is appended to *output_file* as one JSON line, with a single
    flush per batch so progress survives an interruption.
    """
    coros = [
        process_sample(
            session, api_url, item, num_completions,
            temperature, max_tokens, semaphore
        )
        for item in samples
    ]
    completed = await tqdm_asyncio.gather(*coros, desc="Processing batch")
    # Persist immediately rather than holding everything in memory until the end.
    for record in completed:
        output_file.write(json.dumps(record) + '\n')
    output_file.flush()
    return completed
async def main_async(args):
    """Drive the full run: print config, load data, fan out API calls, report stats."""
    divider = '=' * 60
    print(f"\n{divider}")
    print("GRPO Data Generation via vLLM API")
    print(divider)
    print(f"API URL: {args.api_url}")
    print(f"Input dataset: {args.sft_dataset}")
    print(f"Output dataset: {args.output_dataset}")
    print(f"Completions per prompt: {args.num_completions}")
    print(f"Max concurrent requests: {args.max_concurrent}")
    print(f"Temperature: {args.temperature}")
    print(f"Max samples: {args.max_samples or 'All'}")
    print(f"{divider}\n")

    print("Loading SFT dataset...")
    samples = load_sft_dataset(args.sft_dataset, args.max_samples)
    print(f"Loaded {len(samples)} samples\n")

    # One shared semaphore caps in-flight API requests across all batches.
    semaphore = asyncio.Semaphore(args.max_concurrent)

    all_results = []
    batch_size = args.batch_size
    num_batches = (len(samples) + batch_size - 1) // batch_size

    async with aiohttp.ClientSession() as session:
        with open(args.output_dataset, 'w') as f_out:
            # Batching exists only for progress reporting; concurrency is
            # governed by the semaphore, not the batch size.
            for batch_idx, start in enumerate(range(0, len(samples), batch_size)):
                batch_samples = samples[start:start + batch_size]
                print(f"\nBatch {batch_idx + 1}/{num_batches} ({len(batch_samples)} samples)")
                batch_results = await process_batch(
                    session=session,
                    api_url=args.api_url,
                    samples=batch_samples,
                    num_completions=args.num_completions,
                    temperature=args.temperature,
                    max_tokens=args.max_tokens,
                    semaphore=semaphore,
                    output_file=f_out,
                )
                all_results.extend(batch_results)

    print(f"\n{divider}")
    print("Generation Complete!")
    print(divider)
    print(f"Generated {len(all_results)} GRPO samples")
    print(f"Output saved to: {args.output_dataset}")

    # Flatten per-sample score lists for summary statistics.
    all_scores = []
    for entry in all_results:
        all_scores.extend(entry['scores'])
    if all_scores:
        print("\nScore Statistics:")
        print(f" Average F1: {sum(all_scores) / len(all_scores):.3f}")
        print(f" Max F1: {max(all_scores):.3f}")
        print(f" Min F1: {min(all_scores):.3f}")
    print(f"\n{divider}\n")
def main():
    """Parse CLI arguments and launch the async generation pipeline."""
    cli = argparse.ArgumentParser(description="Generate GRPO dataset using vLLM API")
    cli.add_argument("--sft_dataset", type=str, required=True, help="Path to SFT dataset")
    cli.add_argument("--output_dataset", type=str, required=True, help="Output GRPO dataset path")
    cli.add_argument("--api_url", type=str, default="http://localhost:8000/v1", help="vLLM API URL")
    cli.add_argument("--num_completions", type=int, default=6, help="Completions per prompt")
    cli.add_argument("--max_concurrent", type=int, default=50, help="Max concurrent API requests")
    cli.add_argument("--batch_size", type=int, default=100, help="Batch size for progress tracking")
    cli.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    cli.add_argument("--max_tokens", type=int, default=512, help="Max tokens per completion")
    cli.add_argument("--max_samples", type=int, default=None, help="Max samples to process")
    # Hand the parsed namespace straight to the async entry point.
    asyncio.run(main_async(cli.parse_args()))


if __name__ == "__main__":
    main()