training-scripts / train_grpo_code.py
chaddy81's picture
Upload train_grpo_code.py with huggingface_hub
0da65ef verified
Raw
History Blame Contribute Delete
9.56 kB
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "datasets",
# "trackio",
# "torch",
# ]
# ///
"""
GRPO Training for Code Generation with Execution-Based Rewards
Continues training from SFT model using GRPO with verifiable code rewards.
The reward function executes generated Python code against test cases.
Model: chaddy81/qwen3-0.6b-multicode-sft (LoRA on Qwen3-0.6B)
Dataset: open-r1/codeforces (verifiable-prompts subset)
Reward: Code execution correctness (0.0 = fail, 1.0 = pass)
"""
import os
import re
import subprocess
import tempfile
from typing import Any
import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
BASE_MODEL = "Qwen/Qwen3-0.6B"  # base checkpoint the SFT adapter is applied to
SFT_ADAPTER = "chaddy81/qwen3-0.6b-multicode-sft"  # adapter produced by the SFT stage
OUTPUT_REPO = "chaddy81/qwen3-0.6b-multicode-grpo"  # Hub repo that receives the result
MAX_EXAMPLES = 1000  # cap on training prompts, kept small for faster training

# Startup banner followed by a summary of the settings above.
print(
    "=" * 60,
    "🚀 GRPO Code Training - Execution-Based Rewards",
    "=" * 60,
    sep="\n",
)
print("\n📦 Configuration:")
print(
    f" Base model: {BASE_MODEL}",
    f" SFT adapter: {SFT_ADAPTER}",
    f" Output: {OUTPUT_REPO}",
    f" Max examples: {MAX_EXAMPLES}",
    sep="\n",
)
# ============================================================================
# Code Execution Reward Function
# ============================================================================
# Fenced-code patterns in priority order: python-tagged, py-tagged, untagged.
# The tag is matched case-insensitively and may be followed by spaces/tabs and
# either LF or CRLF, so "```Python", "```python3 " and Windows line endings are
# all recognised instead of falling through to the raw-text fallback.
_FENCE_PATTERNS = tuple(
    re.compile(rf"```[ \t]*{tag}[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
    for tag in (r"python3?", r"py", r"")
)


def extract_python_code(text: str) -> str:
    """Extract Python source from a model completion.

    Lookup order:
      1. The last fenced block tagged ``python``/``python3``, else the last
         one tagged ``py``, else the last untagged fenced block.
      2. The text following the last occurrence of the first marker found
         among ``Solution:``, ``Answer:``, ``Code:`` (if non-empty).
      3. The whole text, stripped (it may already be raw code).

    Args:
        text: Raw completion text produced by the model.

    Returns:
        The extracted code with surrounding whitespace removed.
    """
    # Prefer fenced blocks; the last match wins because models usually put
    # the final answer after any intermediate snippets.
    for pattern in _FENCE_PATTERNS:
        bodies = pattern.findall(text)
        if bodies:
            return bodies[-1].strip()

    # No fenced block: look for a textual marker introducing the code.
    for marker in ("Solution:", "Answer:", "Code:"):
        if marker in text:
            code_part = text.split(marker)[-1].strip()
            if code_part:
                return code_part

    # Fallback: treat the entire completion as code.
    return text.strip()
def run_python_code(code: str, stdin_input: str, timeout: float = 3.0) -> tuple[bool, str]:
    """Run ``code`` as a standalone Python script and capture its stdout.

    SECURITY NOTE: the code is model-generated and is executed with the
    privileges of this process, without any sandboxing. Only run this in a
    disposable/isolated environment.

    Args:
        code: Python source to execute.
        stdin_input: Text fed to the script on standard input (``None`` is
            treated as empty input so the child never inherits our stdin).
        timeout: Wall-clock limit in seconds for the child process.

    Returns:
        ``(True, stripped_stdout)`` when the script exits with status 0.
        ``(False, reason)`` otherwise, where ``reason`` is ``"TIMEOUT"``,
        ``"RUNTIME_ERROR: ..."`` (launch failure or non-zero exit status) or
        ``"SETUP_ERROR: ..."`` (the script file could not be written).
        This function does not raise.
    """
    # Local import keeps the block self-contained; the file's import header
    # does not bring in ``sys``.
    import sys

    temp_file = None
    try:
        try:
            # Write explicitly as UTF-8 (Python's default source encoding) so
            # non-ASCII characters in generated code do not depend on the locale.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as f:
                temp_file = f.name  # record the path first so cleanup always sees it
                f.write(code)
        except Exception as e:
            return False, f"SETUP_ERROR: {e}"

        try:
            result = subprocess.run(
                # Re-use the running interpreter; fall back to PATH lookup
                # only if the interpreter path is unavailable.
                [sys.executable or "python3", temp_file],
                input=stdin_input or "",
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return False, "TIMEOUT"
        except Exception as e:
            return False, f"RUNTIME_ERROR: {e}"

        # A crashed script must not count as a successful run, even if it
        # printed something before failing.
        if result.returncode != 0:
            return False, f"RUNTIME_ERROR: exit code {result.returncode}"
        return True, result.stdout.strip()
    finally:
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not change
                # the reported execution result.
                pass
def normalize_output(output: str) -> str:
    """Return ``output`` with outer whitespace removed and every line trimmed.

    The text is stripped as a whole, split on ``"\\n"``, each piece is
    stripped individually, and the pieces are re-joined with ``"\\n"``.
    """
    rows = output.strip().split('\n')
    return '\n'.join(map(str.strip, rows))
def code_execution_reward(
    completions: list[str],
    official_tests: list[list[dict]],
    examples: list[list[dict]],
    **kwargs
) -> list[float]:
    """Score each completion by running its code against the sample's tests.

    For every (completion, official tests, examples) triple, the code is
    extracted from the completion and executed on up to two official tests
    plus up to two examples. Each test is a dict read via its ``'input'`` and
    ``'output'`` keys; a test passes when the run succeeds and the normalized
    stdout equals the normalized expected output.

    Returns:
        One reward per completion, in input order:
        - 1.0 when every selected test passes,
        - 0.5 * (passed / total) when only some pass,
        - 0.0 when none pass, when no tests are available, or when the
          extracted code is empty or shorter than 10 characters.
    """

    def score(completion, tests, exs) -> float:
        """Compute the reward for a single completion."""
        code = extract_python_code(completion)
        if not code or len(code) < 10:
            return 0.0

        # Keep at most two cases from each source to bound execution time.
        cases = []
        if tests:
            cases.extend(tests[:2])
        if exs:
            cases.extend(exs[:2])
        if not cases:
            return 0.0

        passed = 0
        for case in cases:
            ok, actual = run_python_code(code, case.get('input', ''), timeout=2.0)
            if ok and normalize_output(actual) == normalize_output(case.get('output', '')):
                passed += 1

        total = len(cases)
        if passed == total:
            return 1.0
        if passed == 0:
            return 0.0
        # Partial passes are worth at most half of the full reward.
        return 0.5 * (passed / total)

    return [
        score(completion, tests, exs)
        for completion, tests, exs in zip(completions, official_tests, examples)
    ]
# ============================================================================
# Dataset Preparation
# ============================================================================
def _is_python_problem_with_tests(row) -> bool:
    """Keep rows whose language is 'python' and that carry at least one test.

    A row qualifies when either its 'official_tests' or its 'examples' field
    is present and non-empty.
    """
    if row.get('language') != 'python':
        return False
    return bool(row.get('official_tests')) or bool(row.get('examples'))


print("\n📥 Loading dataset...")
dataset = load_dataset("open-r1/codeforces", name="verifiable-prompts", split="train")
print(f" Total examples: {len(dataset)}")

# Drop everything that is not a Python problem with runnable test cases.
print(" Filtering for Python problems with tests...")
dataset = dataset.filter(_is_python_problem_with_tests)
print(f" Filtered: {len(dataset)}")

# Cap the dataset: shuffle with a fixed seed, then keep the first MAX_EXAMPLES.
# NOTE(review): the reviewed copy lost its indentation; the "Limited to" message
# is assumed to belong inside this branch.
if len(dataset) > MAX_EXAMPLES:
    dataset = dataset.shuffle(seed=42).select(range(MAX_EXAMPLES))
    print(f" Limited to: {MAX_EXAMPLES}")

print(f"\n✅ Final dataset: {len(dataset)} examples")
# ============================================================================
# Model Loading - Merge SFT then save for GRPO
# ============================================================================
print("\n🔧 Loading and preparing model...")

# Stage 1: load the base checkpoint on CPU in bfloat16 and attach the SFT adapter.
print(" Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    trust_remote_code=True,
)

print(" Loading SFT adapter...")
model = PeftModel.from_pretrained(base_model, SFT_ADAPTER)

# Fold the adapter weights into the base model so that a new adapter can be
# trained on top of a plain checkpoint.
print(" Merging SFT adapter into base model...")
model = model.merge_and_unload()

# Stage 2: write the merged weights to disk; the trainer later loads from this path.
merged_path = "/tmp/merged_sft_model"
print(f" Saving merged model to {merged_path}...")
model.save_pretrained(merged_path, safe_serialization=True)

# The tokenizer comes from the adapter repo and is stored next to the merged weights.
print(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(SFT_ADAPTER, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when none is defined.
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): indentation was flattened in the reviewed copy; the padding side
# is assumed to be set unconditionally — confirm against the original script.
tokenizer.padding_side = "left"
tokenizer.save_pretrained(merged_path)

# Drop the in-memory copies before the trainer loads its own.
del model, base_model
torch.cuda.empty_cache()
print(" ✅ Merged model saved")
# ============================================================================
# GRPO Training with fresh LoRA
# ============================================================================
print("\n⚙️ Configuring GRPO trainer...")

# Fresh LoRA adapter for the GRPO stage, applied to the attention projections
# only, with a small rank for efficiency.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

config = GRPOConfig(
    # --- Output directory and Hub upload ---
    output_dir="qwen3-grpo-code",
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="every_save",
    hub_private_repo=False,
    # --- GRPO sampling ---
    # NOTE(review): some TRL releases require the effective batch size to be
    # divisible by num_generations — confirm against the installed version.
    num_generations=4,
    max_completion_length=256,  # short completions keep training fast
    # --- Schedule and batch sizing ---
    num_train_epochs=1,
    per_device_train_batch_size=1,  # small batch to limit memory use
    gradient_accumulation_steps=8,
    learning_rate=5e-7,
    # --- Optimisation ---
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    # --- Logging and checkpointing ---
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    # --- Experiment tracking ---
    report_to="trackio",
    project="qwen3-grpo-code",
    run_name="grpo-codeforces-v2",
)

# The model is passed as a path, so the trainer loads the merged checkpoint
# itself and wraps it with the new LoRA adapter.
# NOTE(review): confirm that the minimum TRL version pinned for this script
# actually ships GRPOTrainer/GRPOConfig.
print(" Initializing trainer with merged SFT model + new LoRA...")
trainer = GRPOTrainer(
    model=merged_path,
    processing_class=tokenizer,
    reward_funcs=code_execution_reward,
    train_dataset=dataset,
    args=config,
    peft_config=peft_config,
)
# Announce the run, then train.
print("\n🚀 Starting GRPO training...")
print(" Training will generate code, execute it, and learn from results.")
print("=" * 60)
trainer.train()

# Upload the result to the Hub.
print("\n💾 Pushing to Hub...")
trainer.push_to_hub()

# Close the trackio run before printing the final summary.
trackio.finish()

print("\n" + "=" * 60)
print("✅ GRPO Training Complete!")
print(f"📦 Model: https://huggingface.co/{OUTPUT_REPO}")
print("📊 Metrics: https://huggingface.co/spaces/chaddy81/trackio")
print("=" * 60)