#!/usr/bin/env python3
"""
GRPO Training Script for gpt-oss-120b using OpenEnv and TRL.
Adapted from the openenv-course repository architecture.
Requirements:
    pip install "trl>=0.17.0" openenv-core transformers datasets accelerate vllm
"""
import os
import sys
import json
from pathlib import Path

from datasets import Dataset
from transformers import AutoTokenizer

# Make the repository root importable so the local env package resolves.
sys.path.insert(0, str(Path(__file__).parent.parent))

from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action

# TRL imports
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
# 1. Configuration
MODEL_NAME = "openai/gpt-oss-120b"
OUTPUT_DIR = "code-debug-grpo-120b"
ENV_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")
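# NOTE: this assumes a CodeDebugEnv HTTP server is already listening at
# ENV_URL; the script does not launch one itself.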
# 2. Setup Persistent Environment Connection
print(f"Connecting to env: {ENV_URL}")
env = CodeDebugEnv(base_url=ENV_URL)
sync_env = env.sync()
sync_env.connect()
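# The synchronous connection is module-global so rollout_func below can reuse
# it across every episode instead of reconnecting per rollout.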
# 3. Setup Tokenizer
print(f"Loading tokenizer for {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 4. System Prompt Definition
SYSTEM_PROMPT = """You are an expert Python debugger and RL agent.
Your task is to fix the buggy code provided to you.
Provide ONLY a valid JSON object matching this schema:
{
    "patch": "The FULL Python function as a string, with the bugs fixed",
    "task_id": "the task_id you were given",
    "think": "Your chain-of-thought reasoning before patching (important for rewards!)"
}
"""
def make_user_prompt(observation):
    return (
        f"Task Description: {observation.task_description}\n\n"
        f"Buggy Code:\n```python\n{observation.buggy_code}\n```\n\n"
        f"Passed {observation.passed} out of {observation.total} tests."
    )
# 5. Rollout Function
def rollout_once(trainer, sync_env, tokenizer, dataset_prompt, system_prompt, max_turns):
    """Execute one full episode and collect the trajectory in the format GRPO expects."""
    result = sync_env.reset()
    observation = result.observation
    prompt_ids = []
    completion_ids = []
    logprobs = []
    composite_rewards = []
    for _turn in range(max_turns):
        if result.done:
            break
        user_prompt = make_user_prompt(observation)
        messages = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.append(rollout_outputs['prompt_ids'])
        completion_ids.append(rollout_outputs['completion_ids'])
        logprobs.append(rollout_outputs['logprobs'])
        completion_text = rollout_outputs.get('text') or tokenizer.decode(
            rollout_outputs['completion_ids'], skip_special_tokens=True
        )
        # Parse the JSON action from the model output. The prompt dictates a
        # single JSON object, so we take the outermost brace pair.
        try:
            start = completion_text.find("{")
            end = completion_text.rfind("}") + 1
            if start != -1 and end > start:
                data = json.loads(completion_text[start:end])
                action = Action(patch=data["patch"], task_id=observation.task_id, think=data.get("think", ""))
            else:
                raise ValueError("No JSON object found in completion")
        except (ValueError, KeyError, TypeError):
            # Fallback: resubmit the unmodified buggy code so the episode can continue.
            action = Action(patch=observation.buggy_code, task_id=observation.task_id, think="")
        # Step the environment with the proposed patch.
        result = sync_env.step(action)
        observation = result.observation
        # The environment already calculates the composite reward (0.0 to 1.0);
        # correctness, format, CoT bonus, and efficiency are all baked in.
        composite_rewards.append(observation.score)
    return {
        'prompt_ids': [pid for sub in prompt_ids for pid in sub],  # flatten turns
        'completion_ids': [cid for sub in completion_ids for cid in sub],
        'logprobs': [lp for sub in logprobs for lp in sub],
        'env_reward': composite_rewards[-1] if composite_rewards else 0.0,
    }
def rollout_func(prompts, trainer=None):
    """Rollout function called by GRPOTrainer."""
    episode_prompt_ids = []
    episode_completion_ids = []
    episode_logprobs = []
    rewards = []
    for prompt_text in prompts:
        episode = rollout_once(
            trainer=trainer,
            sync_env=sync_env,
            tokenizer=tokenizer,
            dataset_prompt=prompt_text,
            system_prompt=SYSTEM_PROMPT,
            max_turns=3,  # keep turns low for heavy models like 120B
        )
        episode_prompt_ids.append(episode['prompt_ids'])
        episode_completion_ids.append(episode['completion_ids'])
        episode_logprobs.append(episode['logprobs'])
        rewards.append(episode['env_reward'])
    return {
        'prompt_ids': episode_prompt_ids,
        'completion_ids': episode_completion_ids,
        'logprobs': episode_logprobs,
        'env_reward': rewards,
    }
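# Extra keys returned by rollout_func (here `env_reward`) are forwarded to the
# reward functions below through **kwargs, one entry per prompt.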
# 6. Reward Functions (Mapped from rollout_func keys)
def composite_env_reward(completions, **kwargs):
    rewards = kwargs.get("env_reward")
    return [float(r) for r in rewards] if rewards else [0.0] * len(completions)
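# e.g. composite_env_reward(["<completion>"], env_reward=[0.75]) -> [0.75]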
# 7. Create Dataset & Config
def main():
    print("Preparing dataset...")
    # Dummy prompts to kick off the rollout loop (the actual env state overrides them).
    dataset = Dataset.from_dict({"prompt": ["Fix the buggy Python code."] * 500})

    # Conservative settings for a 120B model: low learning rate, heavy gradient
    # accumulation, and colocated vLLM generation to share GPUs with training.
    grpo_config = GRPOConfig(
        num_train_epochs=1,
        learning_rate=1e-6,  # lower LR for 120B
        gradient_accumulation_steps=128,
        per_device_train_batch_size=1,
        warmup_steps=10,
        num_generations=2,
        max_completion_length=512,
        max_prompt_length=1500,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.9,  # maximize for 120B
        output_dir=OUTPUT_DIR,
        logging_steps=1,
        save_steps=50,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )
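    # With per_device_train_batch_size=1 and gradient_accumulation_steps=128,
    # each optimizer step accumulates gradients over 128 episodes per device.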
print(f"Initializing GRPOTrainer for {MODEL_NAME}...")
trainer = GRPOTrainer(
model=MODEL_NAME,
processing_class=tokenizer,
reward_funcs=[composite_env_reward],
train_dataset=dataset,
args=grpo_config,
rollout_func=rollout_func,
)
print("Starting training...")
trainer.train()
sync_env.close()
trainer.save_model(OUTPUT_DIR)
print("Training complete! Model saved.")
if __name__ == "__main__":
    main()