Spaces:
Sleeping
Sleeping
File size: 6,964 Bytes
#!/usr/bin/env python3
"""
GRPO Training Script for gpt-oss-120b using OpenEnv and TRL.
Adapted from the openenv-course repository architecture.
Requirements:
pip install "trl>=0.17.0" openenv-core transformers datasets accelerate vllm
"""
import os
import sys
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action
# TRL imports
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
# 1. Configuration
# Checkpoint identifier used for both the tokenizer and the trainer.
MODEL_NAME = "openai/gpt-oss-120b"
# Directory handed to the trainer config and to the final save_model call.
OUTPUT_DIR = "code-debug-grpo-120b"
# Environment server endpoint; set OPENENV_URL to override the local default.
ENV_URL = os.environ.get("OPENENV_URL", "http://127.0.0.1:8000")
# 2. Setup Persistent Environment Connection
# The client is created and connected at import time; rollout_func reuses the
# resulting module-level `sync_env` for every episode.
print(f"Connecting to env: {ENV_URL}")
env = CodeDebugEnv(base_url=ENV_URL)
sync_env = env.sync()
sync_env.connect()

# 3. Setup Tokenizer
print(f"Loading tokenizer for {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # No padding token is defined: reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
# 4. System Prompt Definition
# Sent as the system message on every turn; asks the model to answer with a
# single JSON object carrying "patch", "task_id" and "think".
SYSTEM_PROMPT = (
    "You are an expert Python debugger and RL agent.\n"
    "Your task is to fix the buggy code provided to you.\n"
    "Provide ONLY a valid JSON object matching this schema:\n"
    "{\n"
    '"patch": "The FULL python function as a string, with the bugs fixed",\n'
    '"task_id": "the task requested",\n'
    '"think": "Your chain-of-thought reasoning before patching (important for rewards!)"\n'
    "}\n"
)
def make_user_prompt(observation):
    """Render an environment observation as the user message for the model.

    Reads ``task_description``, ``buggy_code``, ``passed`` and ``total`` from
    ``observation`` and returns three sections separated by blank lines: the
    task description, the buggy code inside a fenced ``python`` block, and the
    current test tally.
    """
    description = observation.task_description
    code = observation.buggy_code
    sections = [
        f"Task Description: {description}",
        f"Buggy Code:\n```python\n{code}\n```",
        f"Passed {observation.passed} out of {observation.total} tests.",
    ]
    return "\n\n".join(sections)
# 5. Rollout Function
def _parse_action(completion_text, observation):
    """Turn a raw model completion into an ``Action`` for the environment.

    The span from the first ``{`` to the last ``}`` of ``completion_text`` is
    parsed as JSON; its ``"patch"`` entry and optional ``"think"`` entry become
    the action. If no such span exists, or parsing/construction fails, the
    fallback action resubmits ``observation.buggy_code`` unchanged with an
    empty ``think``.
    """
    start = completion_text.find("{")
    end = completion_text.rfind("}")
    try:
        # rfind returns -1 when no "}" exists; in both that case and when the
        # last "}" precedes the first "{", there is no JSON object to parse.
        if start == -1 or end < start:
            raise ValueError("No JSON found")
        data = json.loads(completion_text[start:end + 1])
        return Action(
            patch=data["patch"],
            task_id=observation.task_id,
            think=data.get("think", ""),
        )
    except Exception:
        # Deliberately broad: any malformed completion must degrade to the
        # fallback action rather than abort the rollout. `Exception` (instead
        # of a bare `except:`) still lets KeyboardInterrupt/SystemExit through.
        return Action(patch=observation.buggy_code, task_id=observation.task_id, think="")


def rollout_once(trainer, sync_env, tokenizer, dataset_prompt, system_prompt, max_turns):
    """Execute one full episode and gather its trajectory for GRPO.

    Args:
        trainer: Passed through to ``generate_rollout_completions``.
        sync_env: Environment client; ``reset()`` and ``step(action)`` must
            return objects exposing ``observation`` (and ``done`` is read
            before every turn).
        tokenizer: Used to render the chat template and, when the rollout
            output carries no ``'text'``, to decode the completion ids.
        dataset_prompt: Unused; each turn's prompt is rebuilt from the current
            environment observation.
        system_prompt: System message placed before every user prompt.
        max_turns: Upper bound on generate/step iterations for the episode.

    Returns:
        dict with ``'prompt_ids'``, ``'completion_ids'`` and ``'logprobs'``
        (each the per-turn sequences concatenated in turn order) and
        ``'env_reward'``: ``observation.score`` after the last step, or ``0.0``
        if no step was taken.
    """
    result = sync_env.reset()
    observation = result.observation

    prompt_ids = []
    completion_ids = []
    logprobs = []
    # Only the score reported after the final step is used as the reward.
    final_score = 0.0

    for _turn in range(max_turns):
        if result.done:
            break

        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': make_user_prompt(observation)},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )

        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        # Accumulate flat token/logprob sequences across turns.
        prompt_ids.extend(rollout_outputs['prompt_ids'])
        completion_ids.extend(rollout_outputs['completion_ids'])
        logprobs.extend(rollout_outputs['logprobs'])

        # Prefer the text supplied by the rollout; otherwise decode the ids.
        completion_text = rollout_outputs.get('text') or tokenizer.decode(
            rollout_outputs['completion_ids'], skip_special_tokens=True
        )

        # Step the environment with the parsed (or fallback) action.
        result = sync_env.step(_parse_action(completion_text, observation))
        observation = result.observation

        # Per the original author's note, the environment's score is already a
        # composite reward (correctness, format, CoT bonus, efficiency).
        final_score = observation.score

    return {
        'prompt_ids': prompt_ids,
        'completion_ids': completion_ids,
        'logprobs': logprobs,
        'env_reward': final_score,
    }
def rollout_func(prompts, trainer=None):
    """Rollout function called by GRPOTrainer.

    Runs ``rollout_once`` sequentially, once per entry in ``prompts``, against
    the module-level ``sync_env`` / ``tokenizer`` with ``SYSTEM_PROMPT``, and
    returns a dict of parallel lists (one element per prompt) under the keys
    ``'prompt_ids'``, ``'completion_ids'``, ``'logprobs'`` and ``'env_reward'``.
    """
    episodes = [
        rollout_once(
            trainer=trainer,
            sync_env=sync_env,
            tokenizer=tokenizer,
            dataset_prompt=prompt,
            system_prompt=SYSTEM_PROMPT,
            max_turns=3,  # Keep turns low for heavy models like 120B
        )
        for prompt in prompts
    ]
    return {
        'prompt_ids': [ep['prompt_ids'] for ep in episodes],
        'completion_ids': [ep['completion_ids'] for ep in episodes],
        'logprobs': [ep['logprobs'] for ep in episodes],
        'env_reward': [ep['env_reward'] for ep in episodes],
    }
# 6. Reward Functions (Mapped from rollout_func keys)
def composite_env_reward(completions, **kwargs):
    """Return the per-completion rewards supplied under ``env_reward``.

    When ``kwargs`` has no ``"env_reward"`` entry, or that entry is empty or
    ``None``, every completion gets ``0.0``. Otherwise each supplied reward is
    converted to ``float`` and returned in order.
    """
    env_rewards = kwargs.get("env_reward")
    if not env_rewards:
        return [0.0] * len(completions)
    return list(map(float, env_rewards))
# 7. Create Dataset & Config
def main():
    """Build the dataset and GRPO config, train the model, then save it.

    The module-level environment connection (``sync_env``) is closed once
    training has finished, and also when setup or training raises, so a
    failed run does not leave the connection open. On success the model is
    saved to ``OUTPUT_DIR`` after the connection is closed.
    """
    try:
        print("Preparing dataset...")
        # Dummy prompts to kick off the rollout loop (the actual env state overrides this)
        dataset = Dataset.from_dict({"prompt": ["Fix the buggy Python code."] * 500})

        # Memory-conscious settings for the 120B model: batch size 1 with heavy
        # gradient accumulation, gradient checkpointing, and colocated vLLM
        # generation.
        # NOTE(review): the earlier comment here mentioned MXFP4 / tensor
        # parallelism, but no argument below configures either.
        grpo_config = GRPOConfig(
            num_train_epochs=1,
            learning_rate=1e-6,  # lower LR for 120B
            gradient_accumulation_steps=128,
            per_device_train_batch_size=1,
            warmup_steps=10,
            num_generations=2,
            max_completion_length=512,
            max_prompt_length=1500,
            use_vllm=True,
            vllm_mode="colocate",
            vllm_gpu_memory_utilization=0.9,  # maximize for 120B
            output_dir=OUTPUT_DIR,
            logging_steps=1,
            save_steps=50,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            push_to_hub=False,
        )

        print(f"Initializing GRPOTrainer for {MODEL_NAME}...")
        trainer = GRPOTrainer(
            model=MODEL_NAME,
            processing_class=tokenizer,
            reward_funcs=[composite_env_reward],
            train_dataset=dataset,
            args=grpo_config,
            rollout_func=rollout_func,
        )

        print("Starting training...")
        trainer.train()
    finally:
        # Release the environment connection whether or not training succeeded.
        sync_env.close()

    trainer.save_model(OUTPUT_DIR)
    print("Training complete! Model saved.")
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()
|