|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| GRPO Training for Code Generation with Execution-Based Rewards
|
|
|
| Continues training from SFT model using GRPO with verifiable code rewards.
|
| The reward function executes generated Python code against test cases.
|
|
|
| Model: chaddy81/qwen3-0.6b-multicode-sft (LoRA on Qwen3-0.6B)
|
| Dataset: open-r1/codeforces (verifiable-prompts subset)
|
Reward: Code execution correctness (1.0 = all sampled tests pass, 0.5 × fraction passed for a partial pass, 0.0 = no test passes)
|
| """
|
|
|
| import os
|
| import re
|
| import subprocess
|
| import tempfile
|
| from typing import Any
|
|
|
| import torch
|
| import trackio
|
| from datasets import load_dataset
|
| from peft import LoraConfig, PeftModel
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| from trl import GRPOTrainer, GRPOConfig
|
|
|
_RULE = "=" * 60
print(_RULE)
print("🚀 GRPO Code Training - Execution-Based Rewards")
print(_RULE)

# Hub identifiers for the base model, the SFT LoRA adapter to start from, and
# the repo the GRPO result is pushed to; MAX_EXAMPLES caps the training set.
BASE_MODEL = "Qwen/Qwen3-0.6B"
SFT_ADAPTER = "chaddy81/qwen3-0.6b-multicode-sft"
OUTPUT_REPO = "chaddy81/qwen3-0.6b-multicode-grpo"
MAX_EXAMPLES = 1000

print("\n📦 Configuration:")
for _label, _value in (
    ("Base model", BASE_MODEL),
    ("SFT adapter", SFT_ADAPTER),
    ("Output", OUTPUT_REPO),
    ("Max examples", MAX_EXAMPLES),
):
    print(f" {_label}: {_value}")
del _label, _value
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_python_code(text: str) -> str:
    """Extract Python source code from a model completion.

    Strategy, in priority order:
      1. Closed ``<think>...</think>`` reasoning segments are removed first, so
         draft reasoning is never mistaken for the final answer.
      2. The last fenced block tagged ``python``, ``python3``, ``py`` or ``py3``
         (tag matched case-insensitively, trailing spaces/CRLF on the fence
         line tolerated) is returned; failing that, the last untagged fence.
      3. Otherwise the text after the first present marker among
         "Solution:", "Answer:", "Code:" (split on its last occurrence) is
         returned, if non-empty.
      4. Otherwise the whole stripped text is returned.

    Returns "" for empty or None input.
    """
    if not text:
        return ""

    cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)

    # Tagged fences are preferred over bare ones; within a pattern the last
    # match wins because models usually put the final answer last.
    patterns = [
        r"```(?:python3?|py3?)[ \t]*\r?\n(.*?)```",
        r"```[ \t]*\r?\n(.*?)```",
    ]
    for pattern in patterns:
        matches = re.findall(pattern, cleaned, re.DOTALL | re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    markers = ["Solution:", "Answer:", "Code:"]
    for marker in markers:
        if marker in cleaned:
            code_part = cleaned.split(marker)[-1].strip()
            if code_part:
                return code_part

    return cleaned.strip()
|
|
|
|
|
def run_python_code(code: str, stdin_input: str, timeout: float = 3.0) -> tuple[bool, str]:
    """Run ``code`` as a standalone Python script with ``stdin_input`` on stdin.

    Args:
        code: Python source to execute.
        stdin_input: Text fed to the program's stdin; None is treated as "".
        timeout: Wall-clock limit in seconds.

    Returns:
        (success, output):
          * (True, stripped stdout) when the process exits with status 0;
          * (False, "TIMEOUT") when it exceeds ``timeout``;
          * (False, "RUNTIME_ERROR: ...") on a non-zero exit status (last
            stderr line included) or if the process cannot be launched;
          * (False, "SETUP_ERROR: ...") if the temporary script cannot be
            written.

    WARNING: this executes untrusted, model-generated code with the full
    privileges of the current process (no sandbox, same working directory).
    """
    import sys  # only needed to locate the running interpreter

    # Use the interpreter running this script so the child matches its version.
    interpreter = sys.executable or "python3"
    temp_file = None
    try:
        # Python source is UTF-8 by definition; don't depend on the locale.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            temp_file = f.name
            f.write(code)

        try:
            result = subprocess.run(
                [interpreter, temp_file],
                # Always pipe stdin: passing None would let the child inherit
                # (and block on) this process's stdin until the timeout.
                input=stdin_input or "",
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return False, "TIMEOUT"
        except Exception as e:
            return False, f"RUNTIME_ERROR: {e}"
    except Exception as e:
        return False, f"SETUP_ERROR: {e}"
    finally:
        # Best-effort cleanup, also when writing the script failed midway.
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass

    # A crash after printing partial output must not count as a pass.
    if result.returncode != 0:
        stderr_lines = result.stderr.strip().splitlines()
        detail = stderr_lines[-1] if stderr_lines else "no stderr"
        return False, f"RUNTIME_ERROR: exit code {result.returncode}: {detail}"
    return True, result.stdout.strip()
|
|
|
|
|
def normalize_output(output: str) -> str:
    """Canonicalize program output so comparisons ignore incidental whitespace.

    Surrounding whitespace is removed from the text as a whole and from every
    newline-separated line; interior blank lines are kept as empty lines.
    """
    return "\n".join(map(str.strip, output.strip().split("\n")))
|
|
|
|
|
def _completion_text(completion: Any) -> str:
    """Return the generated text of a single completion.

    Accepts a plain string, or a chat-message list of the form
    ``[{"role": "assistant", "content": ...}]``. TRL uses the list form for
    conversational datasets. For that form, the last message's content is used.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list) and completion and isinstance(completion[-1], dict):
        return str(completion[-1].get("content") or "")
    return "" if completion is None else str(completion)


def _select_tests(official: Any, examples: Any, limit: int) -> list[tuple[str, str]]:
    """Pick up to ``limit`` official tests, then up to ``limit`` examples.

    Returns (stdin, expected_stdout) pairs. Missing or None fields become "".
    A test that repeats one already selected (same stripped input and
    normalized output) is kept only once. This covers, e.g., a statement example
    that is also an official test. It would otherwise run twice and be
    double-weighted in the partial-credit fraction.
    """
    selected: list[tuple[str, str]] = []
    seen: set = set()
    for source in (official, examples):
        for test in (source or [])[:limit]:
            test_input = test.get("input") or ""
            expected = test.get("output") or ""
            key = (test_input.strip(), normalize_output(expected))
            if key not in seen:
                seen.add(key)
                selected.append((test_input, expected))
    return selected


def code_execution_reward(
    completions: list[Any],
    official_tests: list[list[dict]],
    examples: list[list[dict]],
    *,
    max_tests_per_source: int = 2,
    timeout: float = 2.0,
    **kwargs
) -> list[float]:
    """
    Reward function that executes generated code against test cases.

    For each completion, Python code is extracted and run against up to
    ``max_tests_per_source`` official tests plus up to
    ``max_tests_per_source`` statement examples (deduplicated). Each run is
    limited to ``timeout`` seconds.

    Args:
        completions: Model outputs, as strings or chat-message lists.
        official_tests: Per-completion list of {"input", "output"} dicts (or None).
        examples: Per-completion list of {"input", "output"} dicts (or None).
        max_tests_per_source: Cap on tests taken from each of the two sources.
        timeout: Per-test wall-clock limit in seconds.
        **kwargs: Other dataset columns / trainer arguments; ignored.

    Returns:
        One reward per completion:
        - 1.0 if code passes all selected tests
        - 0.5 * (fraction passed) if it passes only some
        - 0.0 if it passes none, yields (almost) no code, or there are no tests
    """
    rewards: list[float] = []

    for completion, tests, exs in zip(completions, official_tests, examples):
        code = extract_python_code(_completion_text(completion))

        # Very short extractions (< 10 chars) are treated as "no code" and
        # are not executed.
        if not code or len(code) < 10:
            rewards.append(0.0)
            continue

        selected = _select_tests(tests, exs, max_tests_per_source)
        if not selected:
            rewards.append(0.0)
            continue

        passed = 0
        for test_input, expected in selected:
            success, actual = run_python_code(code, test_input, timeout=timeout)
            if success and normalize_output(actual) == normalize_output(expected):
                passed += 1

        total = len(selected)
        if passed == total:
            rewards.append(1.0)
        elif passed:
            # Partial credit is capped at 0.5 so a full pass always dominates.
            rewards.append(0.5 * (passed / total))
        else:
            rewards.append(0.0)

    return rewards
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n📥 Loading dataset...")
dataset = load_dataset(
    "open-r1/codeforces",
    name="verifiable-prompts",
    split="train",
)
print(f" Total examples: {len(dataset)}")


def _is_gradable_python_problem(example: dict) -> bool:
    """Return True for Python-prompt problems the stdout-diff reward can grade.

    The reward feeds each test input on stdin and compares stdout exactly
    (modulo whitespace). Problems needing anything else would always score 0
    and only add noise, so they are dropped. Per the open-r1/codeforces
    dataset card, these are:
      * ``input_mode == "file"``: the program reads/writes files, not stdio;
      * a non-empty ``generated_checker``: several outputs are accepted;
      * a non-empty ``interaction_format``: an interactive judge is required.
    A field missing from a row is treated as not applicable. In that case the
    check reduces to "language is python and the row has tests".
    """
    if example.get("language") != "python":
        return False
    if example.get("input_mode") == "file":
        return False
    if example.get("generated_checker"):
        return False
    if example.get("interaction_format"):
        return False
    return bool(example.get("official_tests")) or bool(example.get("examples"))


print(" Filtering for Python problems with tests...")
dataset = dataset.filter(_is_gradable_python_problem)
print(f" Filtered: {len(dataset)}")

# Fixed-seed subsample keeps the training set reproducible across runs.
if len(dataset) > MAX_EXAMPLES:
    dataset = dataset.shuffle(seed=42).select(range(MAX_EXAMPLES))
    print(f" Limited to: {MAX_EXAMPLES}")

print(f"\n✅ Final dataset: {len(dataset)} examples")
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n🔧 Loading and preparing model...")


def _build_merged_checkpoint(save_dir):
    """Fold the SFT LoRA adapter into the base weights and write them to ``save_dir``.

    The base model is loaded on CPU in bfloat16, wrapped with the SFT adapter,
    merged, and saved as safetensors. The tokenizer is loaded from the adapter
    repo, given EOS as pad token if it has none, switched to left padding, and
    saved next to the weights. Returns the tokenizer. The model objects are
    local and are released when this function returns.
    """
    print(" Loading base model...")
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
    )

    print(" Loading SFT adapter...")
    adapted = PeftModel.from_pretrained(base, SFT_ADAPTER)

    print(" Merging SFT adapter into base model...")
    merged = adapted.merge_and_unload()

    print(f" Saving merged model to {save_dir}...")
    merged.save_pretrained(save_dir, safe_serialization=True)

    print(" Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(SFT_ADAPTER, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    tok.save_pretrained(save_dir)
    return tok


# The trainer later loads the merged checkpoint from this directory.
merged_path = "/tmp/merged_sft_model"
tokenizer = _build_merged_checkpoint(merged_path)
torch.cuda.empty_cache()

print(" ✅ Merged model saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n⚙️ Configuring GRPO trainer...")

# A fresh LoRA adapter, trained on top of the merged SFT weights, restricted
# to the attention projections.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# GRPOConfig arguments, grouped by concern and merged into a single call.
_hub_args = {
    "output_dir": "qwen3-grpo-code",
    "push_to_hub": True,
    "hub_model_id": OUTPUT_REPO,
    "hub_strategy": "every_save",
    "hub_private_repo": False,
}
# 4 sampled completions per prompt, each capped at 256 new tokens.
# NOTE(review): 256 tokens is tight if Qwen3 emits <think> reasoning before
# the code. Truncated completions contain no runnable code and score 0.
_generation_args = {
    "num_generations": 4,
    "max_completion_length": 256,
}
# NOTE(review): 5e-7 is a very low learning rate for a freshly initialised
# LoRA adapter. Confirm this is intentional, since it may barely move the weights.
_optimization_args = {
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 5e-7,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "bf16": True,
    "gradient_checkpointing": True,
}
_logging_args = {
    "logging_steps": 10,
    "save_strategy": "steps",
    "save_steps": 50,
    "save_total_limit": 2,
    "report_to": "trackio",
    "project": "qwen3-grpo-code",
    "run_name": "grpo-codeforces-v2",
}
config = GRPOConfig(
    **_hub_args,
    **_generation_args,
    **_optimization_args,
    **_logging_args,
)

print(" Initializing trainer with merged SFT model + new LoRA...")
trainer = GRPOTrainer(
    model=merged_path,
    args=config,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=code_execution_reward,
    peft_config=peft_config,
)

_rule = "=" * 60
print("\n🚀 Starting GRPO training...")
print(" Training will generate code, execute it, and learn from results.")
print(_rule)

trainer.train()

print("\n💾 Pushing to Hub...")
trainer.push_to_hub()

trackio.finish()

print("\n" + _rule)
print("✅ GRPO Training Complete!")
print(f"📦 Model: https://huggingface.co/{OUTPUT_REPO}")
print("📊 Metrics: https://huggingface.co/spaces/chaddy81/trackio")
print(_rule)
|
|
|